From 6edf2e2bc9e2f88842b1da31244cbab4fb2aff40 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Feb 2019 16:50:39 -0500 Subject: [PATCH] Introduce FlightPutWriter --- cpp/src/arrow/flight/client.cc | 41 ++++++++++++------- cpp/src/arrow/flight/client.h | 28 +++++++++---- cpp/src/arrow/flight/internal.cc | 3 +- cpp/src/arrow/flight/internal.h | 3 +- cpp/src/arrow/flight/serialization-internal.h | 26 ++++++------ .../arrow/flight/test-integration-client.cc | 10 ++--- 6 files changed, 67 insertions(+), 44 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 64d9c82fa552b..26cf15bf1db94 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -115,16 +115,13 @@ class FlightClient; /// \brief A RecordBatchWriter implementation that writes to a Flight /// DoPut stream. -class FlightStreamWriter : public ipc::RecordBatchWriter { +class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter { public: - explicit FlightStreamWriter(std::unique_ptr rpc, - const FlightDescriptor& descriptor, - const std::shared_ptr& schema, - MemoryPool* pool = default_memory_pool()) - : rpc_(std::move(rpc)), - descriptor_(descriptor), - schema_(schema), - pool_(pool) {} + explicit FlightPutWriterImpl(std::unique_ptr rpc, + const FlightDescriptor& descriptor, + const std::shared_ptr& schema, + MemoryPool* pool = default_memory_pool()) + : rpc_(std::move(rpc)), descriptor_(descriptor), schema_(schema), pool_(pool) {} Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { IpcPayload payload; @@ -171,6 +168,22 @@ class FlightStreamWriter : public ipc::RecordBatchWriter { friend class FlightClient; }; +FlightPutWriter::~FlightPutWriter() {} + +FlightPutWriter::FlightPutWriter(std::unique_ptr impl) { + impl_ = std::move(impl); +} + +Status FlightPutWriter::WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) { + return impl_->WriteRecordBatch(batch, allow_64bit); +} + +Status FlightPutWriter::Close() { return impl_->Close(); } + +void FlightPutWriter::set_memory_pool(MemoryPool* pool) { + return impl_->set_memory_pool(pool); +} + class FlightClient::FlightClientImpl { public: Status Connect(const std::string& host, int port) { @@ -277,10 +290,10 @@ class FlightClient::FlightClientImpl { } Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream) { + std::unique_ptr* stream) { std::unique_ptr rpc(new ClientRpc); - std::unique_ptr out( - new FlightStreamWriter(std::move(rpc), descriptor, schema)); + std::unique_ptr out( + new FlightPutWriter::FlightPutWriterImpl(std::move(rpc), descriptor, schema)); std::unique_ptr> write_stream( stub_->DoPut(&out->rpc_->context, &out->response)); @@ -302,7 +315,7 @@ class FlightClient::FlightClientImpl { } out->set_stream(std::move(write_stream)); - *stream = std::move(out); + *stream = std::unique_ptr(new FlightPutWriter(std::move(out))); return Status::OK(); } @@ -350,7 +363,7 @@ Status FlightClient::DoGet(const Ticket& ticket, const std::shared_ptr& Status FlightClient::DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream) { + std::unique_ptr* stream) { return impl_->DoPut(descriptor, schema, stream); } diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 730d8b10f63bb..0ef96c500cfa0 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -24,6 +24,7 @@ #include #include +#include "arrow/ipc/writer.h" #include "arrow/status.h" #include "arrow/util/visibility.h" @@ -35,14 +36,10 @@ class RecordBatch; class RecordBatchReader; class Schema; -namespace ipc { - -class RecordBatchWriter; - -} // namespace ipc - namespace flight { +class FlightPutWriter; + /// \brief Client class for Arrow Flight RPC services (gRPC-based). /// API experimental for now class ARROW_EXPORT FlightClient { @@ -108,7 +105,7 @@ class ARROW_EXPORT FlightClient { /// \param[out] stream a writer to write record batches to /// \return Status Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream); + std::unique_ptr* stream); private: FlightClient(); @@ -116,5 +113,22 @@ class ARROW_EXPORT FlightClient { std::unique_ptr impl_; }; +/// \brief An interface to upload record batches to a Flight server +class ARROW_EXPORT FlightPutWriter : public ipc::RecordBatchWriter { + public: + ~FlightPutWriter(); + + Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override; + Status Close() override; + void set_memory_pool(MemoryPool* pool) override; + + private: + class FlightPutWriterImpl; + explicit FlightPutWriter(std::unique_ptr impl); + std::unique_ptr impl_; + + friend class FlightClient; +}; + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc index b614dd5b3ffc0..a614450e8d0a0 100644 --- a/cpp/src/arrow/flight/internal.cc +++ b/cpp/src/arrow/flight/internal.cc @@ -133,8 +133,7 @@ void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket) { // FlightData -Status FromProto(const pb::FlightData& pb_data, - FlightDescriptor* descriptor, +Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, std::unique_ptr* message) { RETURN_NOT_OK(internal::FromProto(pb_data.flight_descriptor(), descriptor)); const std::string& header = pb_data.data_header(); diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h index a4bafd2693df9..7f9bda138cbb1 100644 --- a/cpp/src/arrow/flight/internal.h +++ b/cpp/src/arrow/flight/internal.h @@ -57,8 +57,7 @@ Status FromProto(const pb::Result& pb_result, Result* result); Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria); Status FromProto(const pb::Location& pb_location, Location* location); Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket); -Status FromProto(const pb::FlightData& pb_data, - FlightDescriptor* descriptor, +Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, std::unique_ptr* message); Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr); Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint); diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h index 4fa7238ddebfd..73d15f6bfaf07 100644 --- a/cpp/src/arrow/flight/serialization-internal.h +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -168,8 +168,7 @@ inline Status FailSerialization(Status status) { inline arrow::Status FailSerialization(arrow::Status status) { if (!status.ok()) { - ARROW_LOG(WARNING) << "Error deserializing Flight message: " - << status.ToString(); + ARROW_LOG(WARNING) << "Error deserializing Flight message: " << status.ToString(); } return status; } @@ -180,9 +179,8 @@ template <> class SerializationTraits { public: static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { - return FailSerialization( - Status(StatusCode::UNIMPLEMENTED, - "internal::FlightData serialization not implemented")); + return FailSerialization(Status( + StatusCode::UNIMPLEMENTED, "internal::FlightData serialization not implemented")); } static Status Deserialize(ByteBuffer* buffer, FlightData* out) { @@ -207,20 +205,20 @@ class SerializationTraits { case pb::FlightData::kFlightDescriptorFieldNumber: { pb::FlightDescriptor pb_descriptor; if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { - return FailSerialization(Status(StatusCode::INTERNAL, - "Unable to parse FlightDescriptor")); + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor")); } } break; case pb::FlightData::kDataHeaderFieldNumber: { if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { - return FailSerialization(Status(StatusCode::INTERNAL, - "Unable to read FlightData metadata")); + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to read FlightData metadata")); } } break; case pb::FlightData::kDataBodyFieldNumber: { if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { - return FailSerialization(Status(StatusCode::INTERNAL, - "Unable to read FlightData body")); + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to read FlightData body")); } } break; default: @@ -278,9 +276,9 @@ class SerializationTraits { // TODO(wesm): messages over 2GB unlikely to be yet supported if (total_size > kInt32Max) { - return FailSerialization(grpc::Status( - grpc::StatusCode::INVALID_ARGUMENT, - "Cannot send record batches exceeding 2GB yet")); + return FailSerialization( + grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Cannot send record batches exceeding 2GB yet")); } // Allocate slice, assign to output buffer diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 3954bc8523b94..8a6414adf2b19 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -68,14 +68,14 @@ arrow::Status ReadToTable(std::unique_ptr& reader, - std::unique_ptr& writer) { + arrow::ipc::RecordBatchWriter& writer) { while (true) { std::shared_ptr chunk; RETURN_NOT_OK(reader->ReadNext(&chunk)); if (chunk == nullptr) break; - RETURN_NOT_OK(writer->WriteRecordBatch(*chunk)); + RETURN_NOT_OK(writer.WriteRecordBatch(*chunk)); } - return writer->Close(); + return writer.Close(); } /// \brief Helper to read a flight into a Table. @@ -114,11 +114,11 @@ int main(int argc, char** argv) { std::shared_ptr original_data; ABORT_NOT_OK(ReadToTable(reader, &original_data)); - std::unique_ptr write_stream; + std::unique_ptr write_stream; ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); std::unique_ptr table_reader( new arrow::TableBatchReader(*original_data)); - ABORT_NOT_OK(CopyReaderToWriter(table_reader, write_stream)); + ABORT_NOT_OK(CopyReaderToWriter(table_reader, *write_stream)); // 2. Get the ticket for the data. std::unique_ptr info;