diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index f748d498e8f22..64d9c82fa552b 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -117,13 +117,14 @@ class FlightClient; /// DoPut stream. class FlightStreamWriter : public ipc::RecordBatchWriter { public: - explicit FlightStreamWriter(std::unique_ptr&& rpc, + explicit FlightStreamWriter(std::unique_ptr rpc, const FlightDescriptor& descriptor, - const std::shared_ptr& schema) - : rpc_{std::move(rpc)}, - descriptor_{descriptor}, - schema_{schema}, - pool_{default_memory_pool()} {} + const std::shared_ptr& schema, + MemoryPool* pool = default_memory_pool()) + : rpc_(std::move(rpc)), + descriptor_(descriptor), + schema_(schema), + pool_(pool) {} Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { IpcPayload payload; @@ -154,7 +155,7 @@ class FlightStreamWriter : public ipc::RecordBatchWriter { private: /// \brief Set the gRPC writer backing this Flight stream. /// \param [in] writer the gRPC writer - void set_stream(std::unique_ptr>&& writer) { + void set_stream(std::unique_ptr> writer) { writer_ = std::move(writer); } diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 61c357e561f1b..730d8b10f63bb 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -100,7 +100,9 @@ class ARROW_EXPORT FlightClient { Status DoGet(const Ticket& ticket, const std::shared_ptr& schema, std::unique_ptr* stream); - /// \brief Upload data to a Flight described by the given descriptor. + /// \brief Upload data to a Flight described by the given + /// descriptor. The caller must call Close() on the returned stream + /// once they are done writing. /// \param[in] descriptor the descriptor of the stream /// \param[in] schema the schema for the data to upload /// \param[out] stream a writer to write record batches to diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 77f665356f1db..1d5b215c8ead3 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -53,15 +53,15 @@ namespace flight { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \ } -class ARROW_EXPORT FlightMessageReaderImpl : public FlightMessageReader { +class FlightMessageReaderImpl : public FlightMessageReader { public: FlightMessageReaderImpl(const FlightDescriptor& descriptor, std::shared_ptr schema, grpc::ServerReader* reader) - : descriptor_{descriptor}, - schema_{schema}, - reader_{reader}, - stream_finished_{false} {} + : descriptor_(descriptor), + schema_(schema), + reader_(reader), + stream_finished_(false) {} const FlightDescriptor& descriptor() const override { return descriptor_; } diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 7952668bfdbcd..3954bc8523b94 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -39,6 +39,60 @@ DEFINE_string(host, "localhost", "Server port to connect to"); DEFINE_int32(port, 31337, "Server port to connect to"); DEFINE_string(path, "", "Resource path to request"); +/// \brief Helper to read a RecordBatchReader into a Table. +arrow::Status ReadToTable(std::unique_ptr& reader, + std::shared_ptr* retrieved_data) { + std::vector> retrieved_chunks; + std::shared_ptr chunk; + while (true) { + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + retrieved_chunks.push_back(chunk); + } + return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + retrieved_data); +} + +/// \brief Helper to read a JsonReader into a Table. +arrow::Status ReadToTable(std::unique_ptr& reader, + std::shared_ptr* retrieved_data) { + std::vector> retrieved_chunks; + std::shared_ptr chunk; + for (int i = 0; i < reader->num_record_batches(); i++) { + RETURN_NOT_OK(reader->ReadRecordBatch(i, &chunk)); + retrieved_chunks.push_back(chunk); + } + return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + retrieved_data); +} + +/// \brief Helper to copy a RecordBatchReader to a RecordBatchWriter. +arrow::Status CopyReaderToWriter(std::unique_ptr& reader, + std::unique_ptr& writer) { + while (true) { + std::shared_ptr chunk; + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + RETURN_NOT_OK(writer->WriteRecordBatch(*chunk)); + } + return writer->Close(); +} + +/// \brief Helper to read a flight into a Table. +arrow::Status ConsumeFlightLocation(const arrow::flight::Location& location, + const arrow::flight::Ticket& ticket, + const std::shared_ptr& schema, + std::shared_ptr* retrieved_data) { + std::unique_ptr read_client; + RETURN_NOT_OK( + arrow::flight::FlightClient::Connect(location.host, location.port, &read_client)); + + std::unique_ptr stream; + RETURN_NOT_OK(read_client->DoGet(ticket, schema, &stream)); + + return ReadToTable(stream, retrieved_data); +} + int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing client for Flight."); gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -57,21 +111,14 @@ int main(int argc, char** argv) { ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(arrow::default_memory_pool(), in_file, &reader)); + std::shared_ptr original_data; + ABORT_NOT_OK(ReadToTable(reader, &original_data)); + std::unique_ptr write_stream; ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); - - std::vector> original_chunks; - for (int i = 0; i < reader->num_record_batches(); i++) { - std::shared_ptr batch; - ABORT_NOT_OK(reader->ReadRecordBatch(i, &batch)); - original_chunks.push_back(batch); - ABORT_NOT_OK(write_stream->WriteRecordBatch(*batch)); - } - ABORT_NOT_OK(write_stream->Close()); - - std::shared_ptr original_data; - ABORT_NOT_OK( - arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data)); + std::unique_ptr table_reader( + new arrow::TableBatchReader(*original_data)); + ABORT_NOT_OK(CopyReaderToWriter(table_reader, write_stream)); // 2. Get the ticket for the data. std::unique_ptr info; @@ -97,26 +144,10 @@ int main(int argc, char** argv) { std::cout << "Verifying location " << location.host << ':' << location.port << std::endl; // 3. Download the data from the server. - std::unique_ptr read_client; - ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location.host, location.port, - &read_client)); - - std::unique_ptr stream; - ABORT_NOT_OK(read_client->DoGet(ticket, schema, &stream)); - - std::vector> retrieved_chunks; - std::shared_ptr chunk; - while (true) { - ABORT_NOT_OK(stream->ReadNext(&chunk)); - if (chunk == nullptr) break; - retrieved_chunks.push_back(chunk); - } - - // 4. Validate that the data is equal. std::shared_ptr retrieved_data; - ABORT_NOT_OK( - arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data)); + ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, schema, &retrieved_data)); + // 4. Validate that the data is equal. if (!original_data->Equals(*retrieved_data)) { std::cerr << "Data does not match!" << std::endl; return 1; diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc index c50e52713b33f..7e201a031943d 100644 --- a/cpp/src/arrow/flight/test-integration-server.cc +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -47,7 +47,7 @@ class FlightIntegrationTestServer : public FlightServerBase { auto data = uploaded_chunks.find(request.path[0]); if (data == uploaded_chunks.end()) { - return Status::KeyError("Could not find flight."); + return Status::KeyError("Could not find flight.", request.path[0]); } auto flight = data->second; @@ -72,7 +72,7 @@ class FlightIntegrationTestServer : public FlightServerBase { std::unique_ptr* data_stream) override { auto data = uploaded_chunks.find(request.ticket); if (data == uploaded_chunks.end()) { - return Status::KeyError("Could not find flight."); + return Status::KeyError("Could not find flight.", request.ticket); } auto flight = data->second;