From 846df73050d93abafe8a8849475473e696243faf Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 25 Jan 2019 10:00:27 -0500 Subject: [PATCH 01/20] Implement put in Java Flight integration server --- .../integration/IntegrationTestServer.java | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java index 7b45e53a149be..2b78cca93aaef 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java @@ -18,6 +18,7 @@ package org.apache.arrow.flight.example.integration; import java.io.File; +import java.io.FileOutputStream; import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.concurrent.Callable; @@ -39,6 +40,9 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.ArrowFileWriter; import org.apache.arrow.vector.ipc.JsonFileReader; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.cli.CommandLine; @@ -131,7 +135,7 @@ public FlightInfo getFlightInfo(FlightDescriptor descriptor) { Schema schema = reader.start(); return new FlightInfo(schema, descriptor, Collections.singletonList(new FlightEndpoint(new Ticket(path.getBytes()), - new Location("localhost", 31338))), + new Location("localhost", 31338))), 0, 0); } catch (Exception e) { throw new RuntimeException(e); @@ -140,12 +144,35 @@ public FlightInfo getFlightInfo(FlightDescriptor descriptor) { @Override public Callable acceptPut(FlightStream flightStream) { - return null; + return () -> { + if (flightStream.getDescriptor().isCommand()) { + throw new UnsupportedOperationException("Commands not supported."); + } + if (flightStream.getDescriptor().getPath().size() < 1) { + throw new IllegalArgumentException("Must provide a path."); + } + String path = flightStream.getDescriptor().getPath().get(0); + File outputFile = new File(path); + if (!outputFile.createNewFile()) { + throw new IllegalStateException("File already exists."); + } + try (VectorSchemaRoot root = flightStream.getRoot(); + FileOutputStream fileOutputStream = new FileOutputStream(outputFile); + ArrowFileWriter writer = new ArrowFileWriter(root, new DictionaryProvider.MapDictionaryProvider(), + fileOutputStream.getChannel())) { + writer.start(); + while (flightStream.next()) { + writer.writeBatch(); + } + writer.end(); + } + return Flight.PutResult.getDefaultInstance(); + }; } @Override public Result doAction(Action action) { - return null; + throw new UnsupportedOperationException("No actions implemented."); } @Override From b3ac01ab54f76e7633cd6432f0f1d806a29b054c Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Jan 2019 11:49:21 -0500 Subject: [PATCH 02/20] Align RecordBatch on client side in Flight DoPut --- .../src/main/java/org/apache/arrow/flight/FlightClient.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index ad7c7e28da242..74e73d1f83623 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -127,7 +127,7 @@ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRo // send the schema to start. ArrowMessage message = new ArrowMessage(descriptor.toProtocol(), root.getSchema()); observer.onNext(message); - return new PutObserver(new VectorUnloader(root, true, false), observer, resultObserver.getFuture()); + return new PutObserver(new VectorUnloader(root, true, true), observer, resultObserver.getFuture()); } public FlightInfo getInfo(FlightDescriptor descriptor) { From a11a5acfda6827cf191eb927164905451943ce3a Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Jan 2019 11:49:46 -0500 Subject: [PATCH 03/20] Don't hang in Flight DoPut if server sends exception --- .../src/main/java/org/apache/arrow/flight/FlightClient.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index 74e73d1f83623..7fcfd52008f98 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -211,7 +211,8 @@ public PutObserver(VectorUnloader unloader, ClientCallStreamObserver Date: Mon, 28 Jan 2019 11:50:58 -0500 Subject: [PATCH 04/20] Fix FromProto for FlightDescriptor --- cpp/src/arrow/flight/internal.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc index b4c6b2addcc11..629796ea36850 100644 --- a/cpp/src/arrow/flight/internal.cc +++ b/cpp/src/arrow/flight/internal.cc @@ -156,7 +156,7 @@ Status FromProto(const pb::FlightDescriptor& pb_descriptor, FlightDescriptor* descriptor) { if (pb_descriptor.type() == pb::FlightDescriptor::PATH) { descriptor->type = FlightDescriptor::PATH; - descriptor->path.resize(pb_descriptor.path_size()); + descriptor->path.reserve(pb_descriptor.path_size()); for (int i = 0; i < pb_descriptor.path_size(); ++i) { descriptor->path.emplace_back(pb_descriptor.path(i)); } From 905ef38fa3c90c520642d0ae14bf3e6d7b3aa608 Mon Sep 17 00:00:00 2001 From: David Li Date: Sat, 26 Jan 2019 13:45:15 -0500 Subject: [PATCH 05/20] Implement C++ Flight DoPut --- cpp/src/arrow/flight/CMakeLists.txt | 1 + cpp/src/arrow/flight/client.cc | 262 ++++++--------- cpp/src/arrow/flight/client.h | 16 +- .../arrow/flight/serialization-internal.cc | 33 ++ cpp/src/arrow/flight/serialization-internal.h | 316 ++++++++++++++++++ cpp/src/arrow/flight/server.cc | 213 +++++------- cpp/src/arrow/flight/server.h | 13 +- .../arrow/flight/test-integration-client.cc | 54 ++- .../arrow/flight/test-integration-server.cc | 102 +++--- cpp/src/arrow/flight/types.h | 6 - integration/integration_test.py | 31 +- .../arrow/flight/example/InMemoryStore.java | 1 + .../integration/IntegrationTestClient.java | 64 ++-- .../integration/IntegrationTestServer.java | 155 ++------- 14 files changed, 720 insertions(+), 547 deletions(-) create mode 100644 cpp/src/arrow/flight/serialization-internal.cc create mode 100644 cpp/src/arrow/flight/serialization-internal.h diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index b8b4d8d336365..1109e0eb9da70 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -69,6 +69,7 @@ set(ARROW_FLIGHT_SRCS Flight.pb.cc Flight.grpc.pb.cc internal.cc + serialization-internal.cc server.cc types.cc ) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index e25c1875d669f..99f88d08a843e 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -22,12 +22,12 @@ #include #include -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/wire_format_lite.h" -#include "grpc/byte_buffer_reader.h" #include "grpcpp/grpcpp.h" +#include "arrow/ipc/dictionary.h" +#include "arrow/ipc/metadata-internal.h" #include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/type.h" @@ -36,161 +36,11 @@ #include "arrow/flight/Flight.grpc.pb.h" #include "arrow/flight/Flight.pb.h" #include "arrow/flight/internal.h" +#include "arrow/flight/serialization-internal.h" -namespace pb = arrow::flight::protocol; - -namespace arrow { -namespace flight { - -/// Internal, not user-visible type used for memory-efficient reads from gRPC -/// stream -struct FlightData { - /// Used only for puts, may be null - std::unique_ptr descriptor; - - /// Non-length-prefixed Message header as described in format/Message.fbs - std::shared_ptr metadata; - - /// Message body - std::shared_ptr body; -}; - -} // namespace flight -} // namespace arrow - -namespace grpc { - -// Customizations to gRPC for more efficient deserialization of FlightData - -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::CodedInputStream; - -using arrow::flight::FlightData; - -bool ReadBytesZeroCopy(const std::shared_ptr& source_data, - CodedInputStream* input, std::shared_ptr* out) { - uint32_t length; - if (!input->ReadVarint32(&length)) { - return false; - } - *out = arrow::SliceBuffer(source_data, input->CurrentPosition(), - static_cast(length)); - return input->Skip(static_cast(length)); -} +using arrow::ipc::internal::IpcPayload; -// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow -// consumers with zero-copy -class GrpcBuffer : public arrow::MutableBuffer { - public: - GrpcBuffer(grpc_slice slice, bool incref) - : MutableBuffer(GRPC_SLICE_START_PTR(slice), - static_cast(GRPC_SLICE_LENGTH(slice))), - slice_(incref ? grpc_slice_ref(slice) : slice) {} - - ~GrpcBuffer() override { - // Decref slice - grpc_slice_unref(slice_); - } - - static arrow::Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr* out) { - // These types are guaranteed by static assertions in gRPC to have the same - // in-memory representation - - auto buffer = *reinterpret_cast(cpp_buf); - - // This part below is based on the Flatbuffers gRPC SerializationTraits in - // flatbuffers/grpc.h - - // Check if this is a single uncompressed slice. - if ((buffer->type == GRPC_BB_RAW) && - (buffer->data.raw.compression == GRPC_COMPRESS_NONE) && - (buffer->data.raw.slice_buffer.count == 1)) { - // If it is, then we can reference the `grpc_slice` directly. - grpc_slice slice = buffer->data.raw.slice_buffer.slices[0]; - - // Increment reference count so this memory remains valid - *out = std::make_shared(slice, true); - } else { - // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read - // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives - // us back a new slice with the refcount already incremented. - grpc_byte_buffer_reader reader; - if (!grpc_byte_buffer_reader_init(&reader, buffer)) { - return arrow::Status::IOError("Internal gRPC error reading from ByteBuffer"); - } - grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); - grpc_byte_buffer_reader_destroy(&reader); - - // Steal the slice reference - *out = std::make_shared(slice, false); - } - - return arrow::Status::OK(); - } - - private: - grpc_slice slice_; -}; - -// Read internal::FlightData from grpc::ByteBuffer containing FlightData -// protobuf without copying -template <> -class SerializationTraits { - public: - static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { - return Status(StatusCode::UNIMPLEMENTED, - "internal::FlightData serialization not implemented"); - } - - static Status Deserialize(ByteBuffer* buffer, FlightData* out) { - if (!buffer) { - return Status(StatusCode::INTERNAL, "No payload"); - } - - std::shared_ptr wrapped_buffer; - GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer)); - - auto buffer_length = static_cast(wrapped_buffer->size()); - CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); - - // TODO(wesm): The 2-parameter version of this function is deprecated - pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */); - - // This is the bytes remaining when using CodedInputStream like this - while (pb_stream.BytesUntilTotalBytesLimit()) { - const uint32_t tag = pb_stream.ReadTag(); - const int field_number = WireFormatLite::GetTagFieldNumber(tag); - switch (field_number) { - case pb::FlightData::kFlightDescriptorFieldNumber: { - pb::FlightDescriptor pb_descriptor; - if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { - return Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor"); - } - } break; - case pb::FlightData::kDataHeaderFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData metadata"); - } - } break; - case pb::FlightData::kDataBodyFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData body"); - } - } break; - default: - DCHECK(false) << "cannot happen"; - } - } - buffer->Clear(); - - // TODO(wesm): Where and when should we verify that the FlightData is not - // malformed or missing components? - - return Status::OK; - } -}; - -} // namespace grpc +namespace pb = arrow::flight::protocol; namespace arrow { namespace flight { @@ -227,7 +77,9 @@ class FlightStreamReader : public RecordBatchReader { // For customizing read path for better memory/serialization efficiency auto custom_reader = reinterpret_cast*>(stream_.get()); - if (custom_reader->Read(&data)) { + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (custom_reader->grpc::ClientReader::Read(&data)) { std::unique_ptr message; // Validate IPC message @@ -259,6 +111,65 @@ class FlightStreamReader : public RecordBatchReader { std::unique_ptr> stream_; }; +class FlightClient; + +/// \brief A RecordBatchWriter implementation that writes to a Flight +/// DoPut stream. +class FlightStreamWriter : public ipc::RecordBatchWriter { + public: + explicit FlightStreamWriter(std::unique_ptr&& rpc, + const FlightDescriptor& descriptor, + const std::shared_ptr& schema) + : rpc_{std::move(rpc)}, + descriptor_{descriptor}, + schema_{schema}, + pool_{default_memory_pool()} {}; + + Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { + IpcPayload payload; + RETURN_NOT_OK(ipc::internal::GetRecordBatchPayload(batch, pool_, &payload)); + auto custom_writer = reinterpret_cast*>(writer_.get()); + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (!custom_writer->grpc::ClientWriter::Write(payload, + grpc::WriteOptions())) { + // Stream ended? + return Status::UnknownError("Could not write record batch to stream"); + } + return Status::OK(); + } + + Status Close() override { + bool finished_writes = writer_->WritesDone(); + RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish())); + if (!finished_writes) { + return Status::UnknownError( + "Could not finish writing record batches before closing"); + } + return Status::OK(); + } + + void set_memory_pool(MemoryPool* pool) override { pool_ = pool; } + + private: + /// \brief Set the gRPC writer backing this Flight stream. + /// \param [in] writer the gRPC writer + void set_stream(std::unique_ptr>&& writer) { + writer_ = std::move(writer); + } + + // TODO: there isn't a way to access this as a user. + protocol::PutResult response; + std::unique_ptr rpc_; + FlightDescriptor descriptor_; + std::shared_ptr schema_; + std::unique_ptr> writer_; + MemoryPool* pool_; + + // We need to reference some fields + friend class FlightClient; +}; + class FlightClient::FlightClientImpl { public: Status Connect(const std::string& host, int port) { @@ -364,8 +275,34 @@ class FlightClient::FlightClientImpl { return Status::OK(); } - Status DoPut(std::unique_ptr* stream) { - return Status::NotImplemented("DoPut"); + Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, + std::unique_ptr* stream) { + std::unique_ptr rpc(new ClientRpc); + std::unique_ptr out( + new FlightStreamWriter(std::move(rpc), descriptor, schema)); + std::unique_ptr> write_stream( + stub_->DoPut(&out->rpc_->context, &out->response)); + + // First write the descriptor and schema to the stream. + pb::FlightData descriptor_message; + RETURN_NOT_OK( + internal::ToProto(descriptor, descriptor_message.mutable_flight_descriptor())); + + std::shared_ptr header_buf; + RETURN_NOT_OK(Buffer::FromString("", &header_buf)); + ipc::DictionaryMemo dictionary_memo; + RETURN_NOT_OK(ipc::SerializeSchema(*schema, out->pool_, &header_buf)); + RETURN_NOT_OK( + ipc::internal::WriteSchemaMessage(*schema, &dictionary_memo, &header_buf)); + descriptor_message.set_data_header(header_buf->ToString()); + + if (!write_stream->Write(descriptor_message, grpc::WriteOptions())) { + return Status::UnknownError("Could not write initial message to stream"); + } + + out->set_stream(std::move(write_stream)); + *stream = std::move(out); + return Status::OK(); } private: @@ -410,9 +347,10 @@ Status FlightClient::DoGet(const Ticket& ticket, const std::shared_ptr& return impl_->DoGet(ticket, schema, stream); } -Status FlightClient::DoPut(const Schema& schema, - std::unique_ptr* stream) { - return Status::NotImplemented("DoPut"); +Status FlightClient::DoPut(const FlightDescriptor& descriptor, + const std::shared_ptr& schema, + std::unique_ptr* stream) { + return impl_->DoPut(descriptor, schema, stream); } } // namespace flight diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 53bb1755b2995..e548f7c76e848 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -35,6 +35,12 @@ class RecordBatch; class RecordBatchReader; class Schema; +namespace ipc { + +class RecordBatchWriter; + +} // namespace ipc + namespace flight { /// \brief Client class for Arrow Flight RPC services (gRPC-based). @@ -94,12 +100,12 @@ class ARROW_EXPORT FlightClient { Status DoGet(const Ticket& ticket, const std::shared_ptr& schema, std::unique_ptr* stream); - /// \brief Initiate DoPut RPC, returns FlightPutWriter interface to - /// write. Not yet implemented - /// \param[in] schema the schema of the stream data - /// \param[out] stream the created stream to write record batches to + /// \brief Upload data to a Flight described by the given descriptor. + /// \param[in] descriptor the descriptor of the stream + /// \param[out] stream a writer to write record batches to /// \return Status - Status DoPut(const Schema& schema, std::unique_ptr* stream); + Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, + std::unique_ptr* stream); private: FlightClient(); diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc new file mode 100644 index 0000000000000..f3dbba47255f4 --- /dev/null +++ b/cpp/src/arrow/flight/serialization-internal.cc @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/serialization-internal.h" + +namespace grpc { + +bool ReadBytesZeroCopy(const std::shared_ptr& source_data, + CodedInputStream* input, std::shared_ptr* out) { + uint32_t length; + if (!input->ReadVarint32(&length)) { + return false; + } + *out = arrow::SliceBuffer(source_data, input->CurrentPosition(), + static_cast(length)); + return input->Skip(static_cast(length)); +} + +} // namespace grpc diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h new file mode 100644 index 0000000000000..e412e247012f1 --- /dev/null +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -0,0 +1,316 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// (De)serialization utilities that hook into gRPC, efficiently +// handling Arrow-encoded data in a gRPC call. + +#pragma once + +#include +#include + +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/wire_format_lite.h" +#include "grpc/byte_buffer_reader.h" +#include "grpcpp/grpcpp.h" + +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/status.h" +#include "arrow/util/logging.h" + +#include "arrow/flight/Flight.grpc.pb.h" +#include "arrow/flight/Flight.pb.h" +#include "arrow/flight/internal.h" +#include "arrow/flight/types.h" + +namespace pb = arrow::flight::protocol; + +using arrow::ipc::internal::IpcPayload; + +constexpr int64_t kInt32Max = std::numeric_limits::max(); + +namespace arrow { +namespace flight { + +/// Internal, not user-visible type used for memory-efficient reads from gRPC +/// stream +struct FlightData { + /// Used only for puts, may be null + std::unique_ptr descriptor; + + /// Non-length-prefixed Message header as described in format/Message.fbs + std::shared_ptr metadata; + + /// Message body + std::shared_ptr body; +}; + +} // namespace flight +} // namespace arrow + +namespace grpc { + +using google::protobuf::internal::WireFormatLite; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; + +using arrow::flight::FlightData; + +// More efficient writing of FlightData to gRPC output buffer +// Implementation of ZeroCopyOutputStream that writes to a fixed-size buffer +class FixedSizeProtoWriter : public ::google::protobuf::io::ZeroCopyOutputStream { + public: + explicit FixedSizeProtoWriter(grpc_slice slice) + : slice_(slice), + bytes_written_(0), + total_size_(static_cast(GRPC_SLICE_LENGTH(slice))) {} + + bool Next(void** data, int* size) override { + // Consume the whole slice + *data = GRPC_SLICE_START_PTR(slice_) + bytes_written_; + *size = total_size_ - bytes_written_; + bytes_written_ = total_size_; + return true; + } + + void BackUp(int count) override { bytes_written_ -= count; } + + int64_t ByteCount() const override { return bytes_written_; } + + private: + grpc_slice slice_; + int bytes_written_; + int total_size_; +}; + +bool ReadBytesZeroCopy(const std::shared_ptr& source_data, + CodedInputStream* input, std::shared_ptr* out); + +// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow +// consumers with zero-copy +class GrpcBuffer : public arrow::MutableBuffer { + public: + GrpcBuffer(grpc_slice slice, bool incref) + : MutableBuffer(GRPC_SLICE_START_PTR(slice), + static_cast(GRPC_SLICE_LENGTH(slice))), + slice_(incref ? grpc_slice_ref(slice) : slice) {} + + ~GrpcBuffer() override { + // Decref slice + grpc_slice_unref(slice_); + } + + static arrow::Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr* out) { + // These types are guaranteed by static assertions in gRPC to have the same + // in-memory representation + + auto buffer = *reinterpret_cast(cpp_buf); + + // This part below is based on the Flatbuffers gRPC SerializationTraits in + // flatbuffers/grpc.h + + // Check if this is a single uncompressed slice. + if ((buffer->type == GRPC_BB_RAW) && + (buffer->data.raw.compression == GRPC_COMPRESS_NONE) && + (buffer->data.raw.slice_buffer.count == 1)) { + // If it is, then we can reference the `grpc_slice` directly. + grpc_slice slice = buffer->data.raw.slice_buffer.slices[0]; + + // Increment reference count so this memory remains valid + *out = std::make_shared(slice, true); + } else { + // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read + // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives + // us back a new slice with the refcount already incremented. + grpc_byte_buffer_reader reader; + if (!grpc_byte_buffer_reader_init(&reader, buffer)) { + return arrow::Status::IOError("Internal gRPC error reading from ByteBuffer"); + } + grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); + grpc_byte_buffer_reader_destroy(&reader); + + // Steal the slice reference + *out = std::make_shared(slice, false); + } + + return arrow::Status::OK(); + } + + private: + grpc_slice slice_; +}; + +// Read internal::FlightData from grpc::ByteBuffer containing FlightData +// protobuf without copying +template <> +class SerializationTraits { + public: + static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { + return Status(StatusCode::UNIMPLEMENTED, + "internal::FlightData serialization not implemented"); + } + + static Status Deserialize(ByteBuffer* buffer, FlightData* out) { + if (!buffer) { + return Status(StatusCode::INTERNAL, "No payload"); + } + + std::shared_ptr wrapped_buffer; + GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer)); + + auto buffer_length = static_cast(wrapped_buffer->size()); + CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); + + // TODO(wesm): The 2-parameter version of this function is deprecated + pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */); + + // This is the bytes remaining when using CodedInputStream like this + while (pb_stream.BytesUntilTotalBytesLimit()) { + const uint32_t tag = pb_stream.ReadTag(); + const int field_number = WireFormatLite::GetTagFieldNumber(tag); + switch (field_number) { + case pb::FlightData::kFlightDescriptorFieldNumber: { + pb::FlightDescriptor pb_descriptor; + if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { + return Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor"); + } + } break; + case pb::FlightData::kDataHeaderFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { + return Status(StatusCode::INTERNAL, "Unable to read FlightData metadata"); + } + } break; + case pb::FlightData::kDataBodyFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { + return Status(StatusCode::INTERNAL, "Unable to read FlightData body"); + } + } break; + default: + DCHECK(false) << "cannot happen"; + } + } + buffer->Clear(); + + // TODO(wesm): Where and when should we verify that the FlightData is not + // malformed or missing components? + + return Status::OK; + } +}; + +// Write FlightData to a grpc::ByteBuffer without extra copying +template <> +class SerializationTraits { + public: + static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "IpcPayload deserialization not implemented"); + } + + static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out, + bool* own_buffer) { + size_t total_size = 0; + + DCHECK_LT(msg.metadata->size(), kInt32Max); + const int32_t metadata_size = static_cast(msg.metadata->size()); + + // 1 byte for metadata tag + total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); + + int64_t body_size = 0; + for (const auto& buffer : msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. + if (!buffer) continue; + + body_size += buffer->size(); + + const int64_t remainder = buffer->size() % 8; + if (remainder) { + body_size += 8 - remainder; + } + } + + // 2 bytes for body tag + // Only written when there are body buffers + if (msg.body_length > 0) { + total_size += + 2 + WireFormatLite::LengthDelimitedSize(static_cast(body_size)); + } + + // TODO(wesm): messages over 2GB unlikely to be yet supported + if (total_size > kInt32Max) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Cannot send record batches exceeding 2GB yet"); + } + + // Allocate slice, assign to output buffer + grpc::Slice slice(total_size); + + // XXX(wesm): for debugging + // std::cout << "Writing record batch with total size " << total_size << std::endl; + + FixedSizeProtoWriter writer(*reinterpret_cast(&slice)); + CodedOutputStream pb_stream(&writer); + + // Write header + WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); + pb_stream.WriteVarint32(metadata_size); + pb_stream.WriteRawMaybeAliased(msg.metadata->data(), + static_cast(msg.metadata->size())); + + // Don't write tag if there are no body buffers + if (msg.body_length > 0) { + // Write body + WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); + pb_stream.WriteVarint32(static_cast(body_size)); + + constexpr uint8_t kPaddingBytes[8] = {0}; + + for (const auto& buffer : msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. + if (!buffer) continue; + + pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast(buffer->size())); + + // Write padding if not multiple of 8 + const int remainder = static_cast(buffer->size() % 8); + if (remainder) { + pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); + } + } + } + + DCHECK_EQ(static_cast(total_size), pb_stream.ByteCount()); + + // Hand off the slice to the returned ByteBuffer + grpc::ByteBuffer tmp(&slice, 1); + out->Swap(&tmp); + *own_buffer = true; + return grpc::Status::OK; + } +}; + +template class grpc::ClientWriter; +template class grpc::ClientReader; + +} // namespace grpc diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 018c079501f2f..77f665356f1db 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -18,15 +18,12 @@ #include "arrow/flight/server.h" #include -#include #include #include -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream.h" -#include "google/protobuf/wire_format_lite.h" #include "grpcpp/grpcpp.h" +#include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -35,6 +32,7 @@ #include "arrow/flight/Flight.grpc.pb.h" #include "arrow/flight/Flight.pb.h" #include "arrow/flight/internal.h" +#include "arrow/flight/serialization-internal.h" #include "arrow/flight/types.h" using FlightService = arrow::flight::protocol::FlightService; @@ -47,145 +45,63 @@ using ServerWriter = grpc::ServerWriter; namespace pb = arrow::flight::protocol; -constexpr int64_t kInt32Max = std::numeric_limits::max(); - -namespace grpc { - -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::CodedOutputStream; +namespace arrow { +namespace flight { -// More efficient writing of FlightData to gRPC output buffer -// Implementation of ZeroCopyOutputStream that writes to a fixed-size buffer -class FixedSizeProtoWriter : public ::google::protobuf::io::ZeroCopyOutputStream { - public: - explicit FixedSizeProtoWriter(grpc_slice slice) - : slice_(slice), - bytes_written_(0), - total_size_(static_cast(GRPC_SLICE_LENGTH(slice))) {} - - bool Next(void** data, int* size) override { - // Consume the whole slice - *data = GRPC_SLICE_START_PTR(slice_) + bytes_written_; - *size = total_size_ - bytes_written_; - bytes_written_ = total_size_; - return true; +#define CHECK_ARG_NOT_NULL(VAL, MESSAGE) \ + if (VAL == nullptr) { \ + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \ } - void BackUp(int count) override { bytes_written_ -= count; } - - int64_t ByteCount() const override { return bytes_written_; } - - private: - grpc_slice slice_; - int bytes_written_; - int total_size_; -}; - -// Write FlightData to a grpc::ByteBuffer without extra copying -template <> -class SerializationTraits { +class ARROW_EXPORT FlightMessageReaderImpl : public FlightMessageReader { public: - static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) { - return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, - "IpcPayload deserialization not implemented"); - } - - static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out, - bool* own_buffer) { - size_t total_size = 0; - - DCHECK_LT(msg.metadata->size(), kInt32Max); - const int32_t metadata_size = static_cast(msg.metadata->size()); - - // 1 byte for metadata tag - total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); - - int64_t body_size = 0; - for (const auto& buffer : msg.body_buffers) { - // Buffer may be null when the row length is zero, or when all - // entries are invalid. - if (!buffer) continue; - - body_size += buffer->size(); - - const int64_t remainder = buffer->size() % 8; - if (remainder) { - body_size += 8 - remainder; - } - } - - // 2 bytes for body tag - // Only written when there are body buffers - if (msg.body_length > 0) { - total_size += - 2 + WireFormatLite::LengthDelimitedSize(static_cast(body_size)); - } - - // TODO(wesm): messages over 2GB unlikely to be yet supported - if (total_size > kInt32Max) { - return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, - "Cannot send record batches exceeding 2GB yet"); + FlightMessageReaderImpl(const FlightDescriptor& descriptor, + std::shared_ptr schema, + grpc::ServerReader* reader) + : descriptor_{descriptor}, + schema_{schema}, + reader_{reader}, + stream_finished_{false} {} + + const FlightDescriptor& descriptor() const override { return descriptor_; } + + std::shared_ptr schema() const override { return schema_; } + + Status ReadNext(std::shared_ptr* out) override { + if (stream_finished_) { + *out = nullptr; + return Status::OK(); } - // Allocate slice, assign to output buffer - grpc::Slice slice(total_size); + auto custom_reader = reinterpret_cast*>(reader_); - // XXX(wesm): for debugging - // std::cout << "Writing record batch with total size " << total_size << std::endl; + FlightData data; + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (custom_reader->grpc::ServerReader::Read(&data)) { + std::unique_ptr message; - FixedSizeProtoWriter writer(*reinterpret_cast(&slice)); - CodedOutputStream pb_stream(&writer); - - // Write header - WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(metadata_size); - pb_stream.WriteRawMaybeAliased(msg.metadata->data(), - static_cast(msg.metadata->size())); - - // Don't write tag if there are no body buffers - if (msg.body_length > 0) { - // Write body - WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(static_cast(body_size)); - - constexpr uint8_t kPaddingBytes[8] = {0}; - - for (const auto& buffer : msg.body_buffers) { - // Buffer may be null when the row length is zero, or when all - // entries are invalid. - if (!buffer) continue; - - pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast(buffer->size())); - - // Write padding if not multiple of 8 - const int remainder = static_cast(buffer->size() % 8); - if (remainder) { - pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); - } + // Validate IPC message + RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message)); + if (message->type() == ipc::Message::Type::RECORD_BATCH) { + return ipc::ReadRecordBatch(*message, schema_, out); + } else { + return Status(StatusCode::Invalid, "Unrecognized message in Flight stream"); } + } else { + // Stream is completed + stream_finished_ = true; + *out = nullptr; + return Status::OK(); } - - DCHECK_EQ(static_cast(total_size), pb_stream.ByteCount()); - - // Hand off the slice to the returned ByteBuffer - grpc::ByteBuffer tmp(&slice, 1); - out->Swap(&tmp); - *own_buffer = true; - return grpc::Status::OK; } -}; -} // namespace grpc - -namespace arrow { -namespace flight { - -#define CHECK_ARG_NOT_NULL(VAL, MESSAGE) \ - if (VAL == nullptr) { \ - return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \ - } + private: + FlightDescriptor descriptor_; + std::shared_ptr schema_; + grpc::ServerReader* reader_; + bool stream_finished_; +}; // This class glues an implementation of FlightServerBase together with the // gRPC service definition, so the latter is not exposed in the public API @@ -293,7 +209,36 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status DoPut(ServerContext* context, grpc::ServerReader* reader, pb::PutResult* response) { - return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, ""); + // Get metadata + pb::FlightData data; + if (reader->Read(&data)) { + FlightDescriptor descriptor; + std::unique_ptr message; + std::shared_ptr schema; + GRPC_RETURN_NOT_OK(internal::FromProto(data.flight_descriptor(), &descriptor)); + std::shared_ptr header_buf; + std::shared_ptr body_buf; + GRPC_RETURN_NOT_OK(Buffer::FromString(data.data_header(), &header_buf)); + GRPC_RETURN_NOT_OK(Buffer::FromString(data.data_body(), &body_buf)); + GRPC_RETURN_NOT_OK(ipc::Message::Open(header_buf, body_buf, &message)); + + if (!message || message->type() != ipc::Message::Type::SCHEMA) { + return internal::ToGrpcStatus( + Status(StatusCode::Invalid, "DoPut must start with schema/descriptor")); + } else { + GRPC_RETURN_NOT_OK(ipc::ReadSchema(*message, &schema)); + + auto message_reader = std::unique_ptr( + new FlightMessageReaderImpl(descriptor, schema, reader)); + return internal::ToGrpcStatus(server_->DoPut(std::move(message_reader))); + } + } else { + // TODO(lihalite): gRPC doesn't let us distinguish between no + // message sent, and message failed to deserialize. IMO, we + // should add logging around the Status returns in + // serialization-internal.h to make debugging such cases easier. + return grpc::Status::OK; + } } grpc::Status ListActions(ServerContext* context, const pb::Empty* request, @@ -376,6 +321,10 @@ Status FlightServerBase::DoGet(const Ticket& request, return Status::NotImplemented("NYI"); } +Status FlightServerBase::DoPut(std::unique_ptr reader) { + return Status::NotImplemented("NYI"); +} + Status FlightServerBase::DoAction(const Action& action, std::unique_ptr* result) { return Status::NotImplemented("NYI"); diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index b3b8239132b7a..f975b8619cd48 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -29,6 +29,7 @@ #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" +#include "arrow/record_batch.h" namespace arrow { @@ -81,6 +82,13 @@ class ARROW_EXPORT RecordBatchStream : public FlightDataStream { std::shared_ptr reader_; }; +/// \brief A reader for IPC payloads uploaded by a client +class ARROW_EXPORT FlightMessageReader : public RecordBatchReader { + public: + /// \brief Get the descriptor for this upload. + virtual const FlightDescriptor& descriptor() const = 0; +}; + /// \brief Skeleton RPC server implementation which can be used to create /// custom servers by implementing its abstract methods class ARROW_EXPORT FlightServerBase { @@ -125,7 +133,10 @@ class ARROW_EXPORT FlightServerBase { /// \return Status virtual Status DoGet(const Ticket& request, std::unique_ptr* stream); - // virtual Status DoPut(std::unique_ptr* reader) = 0; + /// \brief Process a stream of IPC payloads sent from a client + /// \param[in] reader a sequence of uploaded record batches + /// \return Status + virtual Status DoPut(std::unique_ptr reader); /// \brief Execute an action, return stream of zero or more results /// \param[in] action the action to execute, with type and body diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 267025a451cc7..a94001ff33713 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -31,6 +31,7 @@ #include "arrow/io/test-common.h" #include "arrow/ipc/json.h" #include "arrow/record_batch.h" +#include "arrow/table.h" #include "arrow/flight/server.h" #include "arrow/flight/test-util.h" @@ -38,7 +39,6 @@ DEFINE_string(host, "localhost", "Server port to connect to"); DEFINE_int32(port, 31337, "Server port to connect to"); DEFINE_string(path, "", "Resource path to request"); -DEFINE_string(output, "", "Where to write requested resource"); int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing client for Flight."); @@ -49,9 +49,38 @@ int main(int argc, char** argv) { arrow::flight::FlightDescriptor descr{ arrow::flight::FlightDescriptor::PATH, "", {FLAGS_path}}; + + // 1. Put the data to the server. + std::unique_ptr reader; + std::shared_ptr in_file; + std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl; + ABORT_NOT_OK(arrow::io::ReadableFile::Open(FLAGS_path, &in_file)); + + int64_t file_size = 0; + ABORT_NOT_OK(in_file->GetSize(&file_size)); + + std::shared_ptr json_buffer; + ABORT_NOT_OK(in_file->Read(file_size, &json_buffer)); + + ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(json_buffer, &reader)); + + std::unique_ptr write_stream; + ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); + + std::vector> original_chunks; + for (int i = 0; i < reader->num_record_batches(); i++) { + std::shared_ptr batch; + ABORT_NOT_OK(reader->ReadRecordBatch(i, &batch)); + original_chunks.push_back(batch); + ABORT_NOT_OK(write_stream->WriteRecordBatch(*batch)); + } + ABORT_NOT_OK(write_stream->Close()); + + // 2. Get the ticket for the data. std::unique_ptr info; ABORT_NOT_OK(client->GetFlightInfo(descr, &info)); + // 3. Download the data from the server. std::shared_ptr schema; ABORT_NOT_OK(info->GetSchema(&schema)); @@ -64,19 +93,28 @@ int main(int argc, char** argv) { std::unique_ptr stream; ABORT_NOT_OK(client->DoGet(ticket, schema, &stream)); - std::shared_ptr out_file; - ABORT_NOT_OK(arrow::io::FileOutputStream::Open(FLAGS_output, &out_file)); - std::shared_ptr writer; - ABORT_NOT_OK(arrow::ipc::RecordBatchFileWriter::Open(out_file.get(), schema, &writer)); - + std::vector> retrieved_chunks; std::shared_ptr chunk; while (true) { ABORT_NOT_OK(stream->ReadNext(&chunk)); if (chunk == nullptr) break; - ABORT_NOT_OK(writer->WriteRecordBatch(*chunk)); + retrieved_chunks.push_back(chunk); } - ABORT_NOT_OK(writer->Close()); + // 4. Validate that the data is equal. + + std::shared_ptr original_data; + std::shared_ptr retrieved_data; + + ABORT_NOT_OK( + arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data)); + ABORT_NOT_OK( + arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data)); + + if (!original_data->Equals(*retrieved_data)) { + std::cerr << "Data does not match!" << std::endl; + return 1; + } return 0; } diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc index 80813e7f19a4c..c50e52713b33f 100644 --- a/cpp/src/arrow/flight/test-integration-server.cc +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -27,6 +27,7 @@ #include "arrow/io/test-common.h" #include "arrow/ipc/json.h" #include "arrow/record_batch.h" +#include "arrow/table.h" #include "arrow/flight/server.h" #include "arrow/flight/test-util.h" @@ -36,57 +37,7 @@ DEFINE_int32(port, 31337, "Server port to listen on"); namespace arrow { namespace flight { -class JsonReaderRecordBatchStream : public FlightDataStream { - public: - explicit JsonReaderRecordBatchStream( - std::unique_ptr&& reader) - : index_(0), pool_(default_memory_pool()), reader_(std::move(reader)) {} - - std::shared_ptr schema() override { return reader_->schema(); } - - Status Next(ipc::internal::IpcPayload* payload) override { - if (index_ >= reader_->num_record_batches()) { - // Signal that iteration is over - payload->metadata = nullptr; - return Status::OK(); - } - - std::shared_ptr batch; - RETURN_NOT_OK(reader_->ReadRecordBatch(index_, &batch)); - index_++; - - if (!batch) { - // Signal that iteration is over - payload->metadata = nullptr; - return Status::OK(); - } else { - return ipc::internal::GetRecordBatchPayload(*batch, pool_, payload); - } - } - - private: - int index_; - MemoryPool* pool_; - std::unique_ptr reader_; -}; - class FlightIntegrationTestServer : public FlightServerBase { - Status ReadJson(const std::string& json_path, - std::unique_ptr* out) { - std::shared_ptr in_file; - std::cout << "Opening JSON file '" << json_path << "'" << std::endl; - RETURN_NOT_OK(io::ReadableFile::Open(json_path, &in_file)); - - int64_t file_size = 0; - RETURN_NOT_OK(in_file->GetSize(&file_size)); - - std::shared_ptr json_buffer; - RETURN_NOT_OK(in_file->Read(file_size, &json_buffer)); - - RETURN_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(json_buffer, out)); - return Status::OK(); - } - Status GetFlightInfo(const FlightDescriptor& request, std::unique_ptr* info) override { if (request.type == FlightDescriptor::PATH) { @@ -94,16 +45,19 @@ class FlightIntegrationTestServer : public FlightServerBase { return Status::Invalid("Invalid path"); } - std::unique_ptr reader; - RETURN_NOT_OK(ReadJson(request.path.back(), &reader)); + auto data = uploaded_chunks.find(request.path[0]); + if (data == uploaded_chunks.end()) { + return Status::KeyError("Could not find flight."); + } + auto flight = data->second; - FlightEndpoint endpoint1({{request.path.back()}, {}}); + FlightEndpoint endpoint1({{request.path[0]}, {}}); FlightInfo::Data flight_data; - RETURN_NOT_OK(internal::SchemaToString(*reader->schema(), &flight_data.schema)); + RETURN_NOT_OK(internal::SchemaToString(*flight->schema(), &flight_data.schema)); flight_data.descriptor = request; flight_data.endpoints = {endpoint1}; - flight_data.total_records = reader->num_record_batches(); + flight_data.total_records = flight->num_rows(); flight_data.total_bytes = -1; FlightInfo value(flight_data); @@ -116,14 +70,44 @@ class FlightIntegrationTestServer : public FlightServerBase { Status DoGet(const Ticket& request, std::unique_ptr* data_stream) override { - std::unique_ptr reader; - RETURN_NOT_OK(ReadJson(request.ticket, &reader)); + auto data = uploaded_chunks.find(request.ticket); + if (data == uploaded_chunks.end()) { + return Status::KeyError("Could not find flight."); + } + auto flight = data->second; - *data_stream = std::unique_ptr( - new JsonReaderRecordBatchStream(std::move(reader))); + *data_stream = std::unique_ptr(new RecordBatchStream( + std::shared_ptr(new TableBatchReader(*flight)))); return Status::OK(); } + + Status DoPut(std::unique_ptr reader) override { + const FlightDescriptor& descriptor = reader->descriptor(); + + if (descriptor.type != FlightDescriptor::DescriptorType::PATH) { + return Status::Invalid("Must specify a path"); + } else if (descriptor.path.size() < 1) { + return Status::Invalid("Must specify a path"); + } + + std::string key = descriptor.path[0]; + + std::vector> retrieved_chunks; + std::shared_ptr chunk; + while (true) { + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + retrieved_chunks.push_back(chunk); + } + std::shared_ptr retrieved_data; + RETURN_NOT_OK(arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + &retrieved_data)); + uploaded_chunks[key] = retrieved_data; + return Status::OK(); + } + + std::unordered_map> uploaded_chunks; }; } // namespace flight diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 0362105bbc592..e4251bdd5d21b 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -152,12 +152,6 @@ class FlightInfo { mutable bool reconstructed_schema_; }; -// TODO(wesm): NYI -class ARROW_EXPORT FlightPutWriter { - public: - virtual ~FlightPutWriter() = default; -}; - /// \brief An iterator to FlightInfo instances returned by ListFlights class ARROW_EXPORT FlightListing { public: diff --git a/integration/integration_test.py b/integration/integration_test.py index 0bced26f15acd..cef4e5697b29e 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -1009,23 +1009,10 @@ def _compare_flight_implementations(self, producer, consumer): print('Testing file {0}'.format(json_path)) print('==========================================================') - name = os.path.splitext(os.path.basename(json_path))[0] - - file_id = guid()[:8] - with producer.flight_server(): - # Have the client request the file - consumer_file_path = os.path.join( - self.temp_dir, - file_id + '_' + name + '.consumer_requested_file') - consumer.flight_request(producer.FLIGHT_PORT, - json_path, consumer_file_path) - - # Validate the file - print('-- Validating file') - consumer.validate(json_path, consumer_file_path) - - # TODO: also have the client upload the file + # Have the client upload the file, then download and + # compare + consumer.flight_request(producer.FLIGHT_PORT, json_path) class Tester(object): @@ -1053,7 +1040,7 @@ def validate(self, json_path, arrow_path): def flight_server(self): raise NotImplementedError - def flight_request(self, port, json_path, arrow_path): + def flight_request(self, port, json_path): raise NotImplementedError @@ -1122,12 +1109,11 @@ def file_to_stream(self, file_path, stream_path): print(' '.join(cmd)) run_cmd(cmd) - def flight_request(self, port, json_path, arrow_path): + def flight_request(self, port, json_path): cmd = ['java', '-cp', self.ARROW_FLIGHT_JAR, self.ARROW_FLIGHT_CLIENT, '-port', str(port), - '-j', json_path, - '-a', arrow_path] + '-j', json_path] if self.debug: print(' '.join(cmd)) run_cmd(cmd) @@ -1230,15 +1216,14 @@ def flight_server(self): server.terminate() server.wait(5) - def flight_request(self, port, json_path, arrow_path): + def flight_request(self, port, json_path): cmd = self.FLIGHT_CLIENT_CMD + [ '-port=' + str(port), '-path=' + json_path, - '-output=' + arrow_path ] if self.debug: print(' '.join(cmd)) - subprocess.run(cmd) + run_cmd(cmd) class JSTester(Tester): diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java index af7445fb416fd..dd06f79479064 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java @@ -145,6 +145,7 @@ public void listActions(StreamListener listener) { listener.onNext(new ActionType("get", "pull a stream. Action must be done via standard get mechanism")); listener.onNext(new ActionType("put", "push a stream. Action must be done via standard get mechanism")); listener.onNext(new ActionType("drop", "delete a flight. Action body is a JSON encoded path.")); + listener.onCompleted(); } @Override diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java index 803a56c6c1afe..3c04a19c9de8d 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -18,7 +18,6 @@ package org.apache.arrow.flight.example.integration; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; import java.util.List; @@ -30,9 +29,11 @@ import org.apache.arrow.flight.Location; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.DictionaryProvider; -import org.apache.arrow.vector.ipc.ArrowFileWriter; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.JsonFileReader; +import org.apache.arrow.vector.util.Validator; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLineParser; import org.apache.commons.cli.DefaultParser; @@ -48,7 +49,6 @@ class IntegrationTestClient { private IntegrationTestClient() { options = new Options(); - options.addOption("a", "arrow", true, "arrow file"); options.addOption("j", "json", true, "json file"); options.addOption("host", true, "The host to connect to."); options.addOption("port", true, "The port to connect to." ); @@ -64,7 +64,7 @@ public static void main(String[] args) { } } - static void fatalError(String message, Throwable e) { + private static void fatalError(String message, Throwable e) { System.err.println(message); System.err.println(e.getMessage()); LOGGER.error(message, e); @@ -72,37 +72,55 @@ static void fatalError(String message, Throwable e) { } private void run(String[] args) throws ParseException, IOException { - CommandLineParser parser = new DefaultParser(); - CommandLine cmd = parser.parse(options, args, false); - - String fileName = cmd.getOptionValue("arrow"); - if (fileName == null) { - throw new IllegalArgumentException("missing arrow file parameter"); - } - File arrowFile = new File(fileName); - if (arrowFile.exists()) { - throw new IllegalArgumentException("arrow file already exists: " + arrowFile.getAbsolutePath()); - } + final CommandLineParser parser = new DefaultParser(); + final CommandLine cmd = parser.parse(options, args, false); final String host = cmd.getOptionValue("host", "localhost"); final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - FlightClient client = new FlightClient(allocator, new Location(host, port)); - FlightInfo info = client.getInfo(FlightDescriptor.path(cmd.getOptionValue("json"))); + final FlightClient client = new FlightClient(allocator, new Location(host, port)); + + final String inputPath = cmd.getOptionValue("j"); + + // 1. Read data from JSON and upload to server. + FlightDescriptor descriptor = FlightDescriptor.path(inputPath); + VectorSchemaRoot jsonRoot; + try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator); + VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { + jsonRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + VectorUnloader unloader = new VectorUnloader(root); + VectorLoader jsonLoader = new VectorLoader(jsonRoot); + FlightClient.ClientStreamListener stream = client.startPut(descriptor, root); + while (reader.read(root)) { + stream.putNext(); + jsonLoader.load(unloader.getRecordBatch()); + root.clear(); + } + stream.completed(); + // Need to call this, or exceptions from the server get swallowed + stream.getResult(); + } + + // 2. Get the ticket for the data. + FlightInfo info = client.getInfo(descriptor); List endpoints = info.getEndpoints(); if (endpoints.isEmpty()) { throw new RuntimeException("No endpoints returned from Flight server."); } + // 3. Download the data from the server. FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket()); - try (VectorSchemaRoot root = stream.getRoot(); - FileOutputStream fileOutputStream = new FileOutputStream(arrowFile); - ArrowFileWriter arrowWriter = new ArrowFileWriter(root, new DictionaryProvider.MapDictionaryProvider(), - fileOutputStream.getChannel())) { + VectorSchemaRoot downloadedRoot; + try (VectorSchemaRoot root = stream.getRoot()) { + downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + VectorLoader loader = new VectorLoader(downloadedRoot); + VectorUnloader unloader = new VectorUnloader(root); while (stream.next()) { - arrowWriter.writeBatch(); + loader.load(unloader.getRecordBatch()); } } + + Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot); } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java index 2b78cca93aaef..eff2f5d4126cc 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java @@ -17,34 +17,11 @@ package org.apache.arrow.flight.example.integration; -import java.io.File; -import java.io.FileOutputStream; -import java.nio.charset.StandardCharsets; -import java.util.Collections; -import java.util.concurrent.Callable; - -import org.apache.arrow.flight.Action; -import org.apache.arrow.flight.ActionType; -import org.apache.arrow.flight.Criteria; -import org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.flight.FlightEndpoint; -import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightProducer; -import org.apache.arrow.flight.FlightServer; -import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.Result; -import org.apache.arrow.flight.Ticket; -import org.apache.arrow.flight.auth.ServerAuthHandler; -import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.flight.example.ExampleFlightServer; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.dictionary.DictionaryProvider; -import org.apache.arrow.vector.ipc.ArrowFileWriter; -import org.apache.arrow.vector.ipc.JsonFileReader; -import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.util.AutoCloseables; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLineParser; import org.apache.commons.cli.DefaultParser; @@ -52,6 +29,7 @@ import org.apache.commons.cli.ParseException; class IntegrationTestServer { + private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestServer.class); private final Options options; private IntegrationTestServer() { @@ -62,17 +40,25 @@ private IntegrationTestServer() { private void run(String[] args) throws Exception { CommandLineParser parser = new DefaultParser(); CommandLine cmd = parser.parse(options, args, false); + final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); - try (final IntegrationFlightProducer producer = new IntegrationFlightProducer(allocator); - final FlightServer server = new FlightServer(allocator, port, producer, ServerAuthHandler.NO_OP)) { - server.start(); - // Print out message for integration test script - System.out.println("Server listening on localhost:" + server.getPort()); - while (true) { - Thread.sleep(30000); + final ExampleFlightServer efs = new ExampleFlightServer(allocator, new Location("localhost", port)); + efs.start(); + // Print out message for integration test script + System.out.println("Server listening on localhost:" + port); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + System.out.println("\nExiting..."); + AutoCloseables.close(efs, allocator); + } catch (Exception e) { + e.printStackTrace(); } + })); + + while (true) { + Thread.sleep(30000); } } @@ -80,104 +66,17 @@ public static void main(String[] args) { try { new IntegrationTestServer().run(args); } catch (ParseException e) { - IntegrationTestClient.fatalError("Error parsing arguments", e); + fatalError("Error parsing arguments", e); } catch (Exception e) { - IntegrationTestClient.fatalError("Runtime error", e); + fatalError("Runtime error", e); } } - static class IntegrationFlightProducer implements FlightProducer, AutoCloseable { - private final BufferAllocator allocator; - - IntegrationFlightProducer(BufferAllocator allocator) { - this.allocator = allocator; - } - - @Override - public void close() { - allocator.close(); - } - - @Override - public void getStream(Ticket ticket, ServerStreamListener listener) { - String path = new String(ticket.getBytes(), StandardCharsets.UTF_8); - File inputFile = new File(path); - try (JsonFileReader reader = new JsonFileReader(inputFile, allocator)) { - Schema schema = reader.start(); - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - listener.start(root); - while (reader.read(root)) { - listener.putNext(); - } - listener.completed(); - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public void listFlights(Criteria criteria, StreamListener listener) { - listener.onCompleted(); - } - - @Override - public FlightInfo getFlightInfo(FlightDescriptor descriptor) { - if (descriptor.isCommand()) { - throw new UnsupportedOperationException("Commands not supported."); - } - if (descriptor.getPath().size() < 1) { - throw new IllegalArgumentException("Must provide a path."); - } - String path = descriptor.getPath().get(0); - File inputFile = new File(path); - try (JsonFileReader reader = new JsonFileReader(inputFile, allocator)) { - Schema schema = reader.start(); - return new FlightInfo(schema, descriptor, - Collections.singletonList(new FlightEndpoint(new Ticket(path.getBytes()), - new Location("localhost", 31338))), - 0, 0); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public Callable acceptPut(FlightStream flightStream) { - return () -> { - if (flightStream.getDescriptor().isCommand()) { - throw new UnsupportedOperationException("Commands not supported."); - } - if (flightStream.getDescriptor().getPath().size() < 1) { - throw new IllegalArgumentException("Must provide a path."); - } - String path = flightStream.getDescriptor().getPath().get(0); - File outputFile = new File(path); - if (!outputFile.createNewFile()) { - throw new IllegalStateException("File already exists."); - } - try (VectorSchemaRoot root = flightStream.getRoot(); - FileOutputStream fileOutputStream = new FileOutputStream(outputFile); - ArrowFileWriter writer = new ArrowFileWriter(root, new DictionaryProvider.MapDictionaryProvider(), - fileOutputStream.getChannel())) { - writer.start(); - while (flightStream.next()) { - writer.writeBatch(); - } - writer.end(); - } - return Flight.PutResult.getDefaultInstance(); - }; - } - - @Override - public Result doAction(Action action) { - throw new UnsupportedOperationException("No actions implemented."); - } - - @Override - public void listActions(StreamListener listener) { - listener.onCompleted(); - } + private static void fatalError(String message, Throwable e) { + System.err.println(message); + System.err.println(e.getMessage()); + LOGGER.error(message, e); + System.exit(1); } + } From 3cb51badd430a634dee32f3b73026f5d72102604 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 29 Jan 2019 13:44:51 -0500 Subject: [PATCH 06/20] Test all returned locations in Flight integration tests --- .../arrow/flight/test-integration-client.cc | 79 +++++++++++-------- integration/integration_test.py | 10 +-- .../integration/IntegrationTestClient.java | 36 ++++++--- 3 files changed, 75 insertions(+), 50 deletions(-) diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index a94001ff33713..64f9dc3cd158e 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -// Client implementation for Flight integration testing. Requests the given -// path from the Flight server, which reads that file and sends it as a stream -// to the client. The client writes the server stream to the IPC file format at -// the given output file path. The integration test script then uses the -// existing integration test tools to compare the output binary with the -// original JSON +// Client implementation for Flight integration testing. Loads +// RecordBatches from the given JSON file and uploads them to the +// Flight server, which stores the data and schema in memory. The +// client then requests the data from the server and compares it to +// the data originally uploaded. #include #include @@ -76,11 +75,14 @@ int main(int argc, char** argv) { } ABORT_NOT_OK(write_stream->Close()); + std::shared_ptr original_data; + ABORT_NOT_OK( + arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data)); + // 2. Get the ticket for the data. std::unique_ptr info; ABORT_NOT_OK(client->GetFlightInfo(descr, &info)); - // 3. Download the data from the server. std::shared_ptr schema; ABORT_NOT_OK(info->GetSchema(&schema)); @@ -89,32 +91,43 @@ int main(int argc, char** argv) { return -1; } - arrow::flight::Ticket ticket = info->endpoints()[0].ticket; - std::unique_ptr stream; - ABORT_NOT_OK(client->DoGet(ticket, schema, &stream)); - - std::vector> retrieved_chunks; - std::shared_ptr chunk; - while (true) { - ABORT_NOT_OK(stream->ReadNext(&chunk)); - if (chunk == nullptr) break; - retrieved_chunks.push_back(chunk); + for (const arrow::flight::FlightEndpoint& endpoint : info->endpoints()) { + const auto& ticket = endpoint.ticket; + + auto locations = endpoint.locations; + if (locations.size() == 0) { + locations = {arrow::flight::Location{FLAGS_host, FLAGS_port}}; + } + + for (const auto location : locations) { + std::cout << "Verifying location " << location.host << ':' << location.port + << std::endl; + // 3. Download the data from the server. + std::unique_ptr read_client; + ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location.host, location.port, + &read_client)); + + std::unique_ptr stream; + ABORT_NOT_OK(read_client->DoGet(ticket, schema, &stream)); + + std::vector> retrieved_chunks; + std::shared_ptr chunk; + while (true) { + ABORT_NOT_OK(stream->ReadNext(&chunk)); + if (chunk == nullptr) break; + retrieved_chunks.push_back(chunk); + } + + // 4. Validate that the data is equal. + std::shared_ptr retrieved_data; + ABORT_NOT_OK( + arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data)); + + if (!original_data->Equals(*retrieved_data)) { + std::cerr << "Data does not match!" << std::endl; + return 1; + } + } } - - // 4. Validate that the data is equal. - - std::shared_ptr original_data; - std::shared_ptr retrieved_data; - - ABORT_NOT_OK( - arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data)); - ABORT_NOT_OK( - arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data)); - - if (!original_data->Equals(*retrieved_data)) { - std::cerr << "Data does not match!" << std::endl; - return 1; - } - return 0; } diff --git a/integration/integration_test.py b/integration/integration_test.py index cef4e5697b29e..fc02d0712006c 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -1004,12 +1004,12 @@ def _compare_flight_implementations(self, producer, consumer): ) print('##########################################################') - for json_path in self.json_files: - print('==========================================================') - print('Testing file {0}'.format(json_path)) - print('==========================================================') + with producer.flight_server(): + for json_path in self.json_files: + print('==========================================================') + print('Testing file {0}'.format(json_path)) + print('==========================================================') - with producer.flight_server(): # Have the client upload the file, then download and # compare consumer.flight_request(producer.FLIGHT_PORT, json_path) diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java index 3c04a19c9de8d..ed450074a767a 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.util.Collections; import java.util.List; import org.apache.arrow.flight.FlightClient; @@ -51,7 +52,7 @@ private IntegrationTestClient() { options = new Options(); options.addOption("j", "json", true, "json file"); options.addOption("host", true, "The host to connect to."); - options.addOption("port", true, "The port to connect to." ); + options.addOption("port", true, "The port to connect to."); } public static void main(String[] args) { @@ -109,18 +110,29 @@ private void run(String[] args) throws ParseException, IOException { throw new RuntimeException("No endpoints returned from Flight server."); } - // 3. Download the data from the server. - FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket()); - VectorSchemaRoot downloadedRoot; - try (VectorSchemaRoot root = stream.getRoot()) { - downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); - VectorLoader loader = new VectorLoader(downloadedRoot); - VectorUnloader unloader = new VectorUnloader(root); - while (stream.next()) { - loader.load(unloader.getRecordBatch()); + for (FlightEndpoint endpoint : info.getEndpoints()) { + // 3. Download the data from the server. + List locations = endpoint.getLocations(); + if (locations.size() == 0) { + locations = Collections.singletonList(new Location(host, port)); } - } + for (Location location : locations) { + System.out.println("Verifying location " + location.getHost() + ":" + location.getPort()); + FlightClient readClient = new FlightClient(allocator, location); + FlightStream stream = readClient.getStream(endpoint.getTicket()); + VectorSchemaRoot downloadedRoot; + try (VectorSchemaRoot root = stream.getRoot()) { + downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + VectorLoader loader = new VectorLoader(downloadedRoot); + VectorUnloader unloader = new VectorUnloader(root); + while (stream.next()) { + loader.load(unloader.getRecordBatch()); + } + } - Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot); + // 4. Validate the data. + Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot); + } + } } } From 111b3e6b1b9f5507df9b8efb91afbac8c86cef2d Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 29 Jan 2019 13:44:33 -0500 Subject: [PATCH 07/20] Fix style/lint issues --- cpp/src/arrow/flight/client.cc | 2 +- cpp/src/arrow/flight/client.h | 3 ++- cpp/src/arrow/flight/server.h | 5 ++--- integration/integration_test.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 99f88d08a843e..f748d498e8f22 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -123,7 +123,7 @@ class FlightStreamWriter : public ipc::RecordBatchWriter { : rpc_{std::move(rpc)}, descriptor_{descriptor}, schema_{schema}, - pool_{default_memory_pool()} {}; + pool_{default_memory_pool()} {} Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { IpcPayload payload; diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index e548f7c76e848..61c357e561f1b 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -92,7 +92,7 @@ class ARROW_EXPORT FlightClient { /// \brief Given a flight ticket and schema, request to be sent the /// stream. Returns record batch stream reader - /// \param[in] ticket + /// \param[in] ticket The flight ticket to use /// \param[in] schema the schema of the stream data as computed by /// GetFlightInfo /// \param[out] stream the returned RecordBatchReader @@ -102,6 +102,7 @@ class ARROW_EXPORT FlightClient { /// \brief Upload data to a Flight described by the given descriptor. /// \param[in] descriptor the descriptor of the stream + /// \param[in] schema the schema for the data to upload /// \param[out] stream a writer to write record batches to /// \return Status Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index f975b8619cd48..b2e8b02be8e7d 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -69,9 +69,9 @@ class ARROW_EXPORT FlightDataStream { /// \brief A basic implementation of FlightDataStream that will provide /// a sequence of FlightData messages to be written to a gRPC stream -/// \param[in] reader produces a sequence of record batches class ARROW_EXPORT RecordBatchStream : public FlightDataStream { public: + /// \param[in] reader produces a sequence of record batches explicit RecordBatchStream(const std::shared_ptr& reader); std::shared_ptr schema() override; @@ -98,8 +98,7 @@ class ARROW_EXPORT FlightServerBase { /// \brief Run an insecure server on localhost at the indicated port. Block /// until server is shut down or otherwise terminates - /// \param[in] port - /// \return Status + /// \param[in] port the port to bind to void Run(int port); /// \brief Shut down the server. Can be called from signal handler or another diff --git a/integration/integration_test.py b/integration/integration_test.py index fc02d0712006c..e7e8edda6ddf1 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -1006,9 +1006,9 @@ def _compare_flight_implementations(self, producer, consumer): with producer.flight_server(): for json_path in self.json_files: - print('==========================================================') + print('=' * 58) print('Testing file {0}'.format(json_path)) - print('==========================================================') + print('=' * 58) # Have the client upload the file, then download and # compare From 3e185cb96f1a1a1766c3cc5b51f844e02ae6b343 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Feb 2019 11:59:15 -0500 Subject: [PATCH 08/20] [ARROW-4409] Add convenience to parse JSON from file --- cpp/src/arrow/flight/test-integration-client.cc | 12 +++--------- cpp/src/arrow/ipc/json.cc | 13 +++++++++++++ cpp/src/arrow/ipc/json.h | 13 +++++++++++++ 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 64f9dc3cd158e..7952668bfdbcd 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -51,17 +51,11 @@ int main(int argc, char** argv) { // 1. Put the data to the server. std::unique_ptr reader; - std::shared_ptr in_file; std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl; + std::shared_ptr in_file; ABORT_NOT_OK(arrow::io::ReadableFile::Open(FLAGS_path, &in_file)); - - int64_t file_size = 0; - ABORT_NOT_OK(in_file->GetSize(&file_size)); - - std::shared_ptr json_buffer; - ABORT_NOT_OK(in_file->Read(file_size, &json_buffer)); - - ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(json_buffer, &reader)); + ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(arrow::default_memory_pool(), + in_file, &reader)); std::unique_ptr write_stream; ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); diff --git a/cpp/src/arrow/ipc/json.cc b/cpp/src/arrow/ipc/json.cc index 61c242ca2dbbb..56fe31cbf5691 100644 --- a/cpp/src/arrow/ipc/json.cc +++ b/cpp/src/arrow/ipc/json.cc @@ -22,6 +22,7 @@ #include #include "arrow/buffer.h" +#include "arrow/io/file.h" #include "arrow/ipc/json-internal.h" #include "arrow/memory_pool.h" #include "arrow/record_batch.h" @@ -157,6 +158,18 @@ Status JsonReader::Open(MemoryPool* pool, const std::shared_ptr& data, return (*reader)->impl_->ParseAndReadSchema(); } +Status JsonReader::Open(MemoryPool* pool, + const std::shared_ptr& in_file, + std::unique_ptr* reader) { + int64_t file_size = 0; + RETURN_NOT_OK(in_file->GetSize(&file_size)); + + std::shared_ptr json_buffer; + RETURN_NOT_OK(in_file->Read(file_size, &json_buffer)); + + return Open(pool, json_buffer, reader); +} + std::shared_ptr JsonReader::schema() const { return impl_->schema(); } int JsonReader::num_record_batches() const { return impl_->num_record_batches(); } diff --git a/cpp/src/arrow/ipc/json.h b/cpp/src/arrow/ipc/json.h index 5c00555de8ec0..aeed7070fe96e 100644 --- a/cpp/src/arrow/ipc/json.h +++ b/cpp/src/arrow/ipc/json.h @@ -33,6 +33,10 @@ class MemoryPool; class RecordBatch; class Schema; +namespace io { +class ReadableFile; +} // namespace io + namespace ipc { namespace internal { namespace json { @@ -95,6 +99,15 @@ class ARROW_EXPORT JsonReader { static Status Open(const std::shared_ptr& data, std::unique_ptr* reader); + /// \brief Create a new JSON reader from a file + /// + /// \param[in] pool a MemoryPool to use for buffer allocations + /// \param[in] in_file a ReadableFile containing JSON data + /// \param[out] reader the returned reader object + /// \return Status + static Status Open(MemoryPool* pool, const std::shared_ptr& in_file, + std::unique_ptr* reader); + /// \brief Return the schema read from the JSON std::shared_ptr schema() const; From 65d6ba2ff41a86381a7bf9d9caab86b620f32d48 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Feb 2019 14:32:08 -0500 Subject: [PATCH 09/20] Clean up C++ Flight integration client --- cpp/src/arrow/flight/client.cc | 15 +-- cpp/src/arrow/flight/client.h | 4 +- cpp/src/arrow/flight/server.cc | 10 +- .../arrow/flight/test-integration-client.cc | 93 ++++++++++++------- .../arrow/flight/test-integration-server.cc | 4 +- 5 files changed, 80 insertions(+), 46 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index f748d498e8f22..64d9c82fa552b 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -117,13 +117,14 @@ class FlightClient; /// DoPut stream. class FlightStreamWriter : public ipc::RecordBatchWriter { public: - explicit FlightStreamWriter(std::unique_ptr&& rpc, + explicit FlightStreamWriter(std::unique_ptr rpc, const FlightDescriptor& descriptor, - const std::shared_ptr& schema) - : rpc_{std::move(rpc)}, - descriptor_{descriptor}, - schema_{schema}, - pool_{default_memory_pool()} {} + const std::shared_ptr& schema, + MemoryPool* pool = default_memory_pool()) + : rpc_(std::move(rpc)), + descriptor_(descriptor), + schema_(schema), + pool_(pool) {} Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { IpcPayload payload; @@ -154,7 +155,7 @@ class FlightStreamWriter : public ipc::RecordBatchWriter { private: /// \brief Set the gRPC writer backing this Flight stream. /// \param [in] writer the gRPC writer - void set_stream(std::unique_ptr>&& writer) { + void set_stream(std::unique_ptr> writer) { writer_ = std::move(writer); } diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 61c357e561f1b..730d8b10f63bb 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -100,7 +100,9 @@ class ARROW_EXPORT FlightClient { Status DoGet(const Ticket& ticket, const std::shared_ptr& schema, std::unique_ptr* stream); - /// \brief Upload data to a Flight described by the given descriptor. + /// \brief Upload data to a Flight described by the given + /// descriptor. The caller must call Close() on the returned stream + /// once they are done writing. /// \param[in] descriptor the descriptor of the stream /// \param[in] schema the schema for the data to upload /// \param[out] stream a writer to write record batches to diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 77f665356f1db..1d5b215c8ead3 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -53,15 +53,15 @@ namespace flight { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \ } -class ARROW_EXPORT FlightMessageReaderImpl : public FlightMessageReader { +class FlightMessageReaderImpl : public FlightMessageReader { public: FlightMessageReaderImpl(const FlightDescriptor& descriptor, std::shared_ptr schema, grpc::ServerReader* reader) - : descriptor_{descriptor}, - schema_{schema}, - reader_{reader}, - stream_finished_{false} {} + : descriptor_(descriptor), + schema_(schema), + reader_(reader), + stream_finished_(false) {} const FlightDescriptor& descriptor() const override { return descriptor_; } diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 7952668bfdbcd..3954bc8523b94 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -39,6 +39,60 @@ DEFINE_string(host, "localhost", "Server port to connect to"); DEFINE_int32(port, 31337, "Server port to connect to"); DEFINE_string(path, "", "Resource path to request"); +/// \brief Helper to read a RecordBatchReader into a Table. +arrow::Status ReadToTable(std::unique_ptr& reader, + std::shared_ptr* retrieved_data) { + std::vector> retrieved_chunks; + std::shared_ptr chunk; + while (true) { + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + retrieved_chunks.push_back(chunk); + } + return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + retrieved_data); +} + +/// \brief Helper to read a JsonReader into a Table. +arrow::Status ReadToTable(std::unique_ptr& reader, + std::shared_ptr* retrieved_data) { + std::vector> retrieved_chunks; + std::shared_ptr chunk; + for (int i = 0; i < reader->num_record_batches(); i++) { + RETURN_NOT_OK(reader->ReadRecordBatch(i, &chunk)); + retrieved_chunks.push_back(chunk); + } + return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + retrieved_data); +} + +/// \brief Helper to copy a RecordBatchReader to a RecordBatchWriter. +arrow::Status CopyReaderToWriter(std::unique_ptr& reader, + std::unique_ptr& writer) { + while (true) { + std::shared_ptr chunk; + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + RETURN_NOT_OK(writer->WriteRecordBatch(*chunk)); + } + return writer->Close(); +} + +/// \brief Helper to read a flight into a Table. +arrow::Status ConsumeFlightLocation(const arrow::flight::Location& location, + const arrow::flight::Ticket& ticket, + const std::shared_ptr& schema, + std::shared_ptr* retrieved_data) { + std::unique_ptr read_client; + RETURN_NOT_OK( + arrow::flight::FlightClient::Connect(location.host, location.port, &read_client)); + + std::unique_ptr stream; + RETURN_NOT_OK(read_client->DoGet(ticket, schema, &stream)); + + return ReadToTable(stream, retrieved_data); +} + int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing client for Flight."); gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -57,21 +111,14 @@ int main(int argc, char** argv) { ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(arrow::default_memory_pool(), in_file, &reader)); + std::shared_ptr original_data; + ABORT_NOT_OK(ReadToTable(reader, &original_data)); + std::unique_ptr write_stream; ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); - - std::vector> original_chunks; - for (int i = 0; i < reader->num_record_batches(); i++) { - std::shared_ptr batch; - ABORT_NOT_OK(reader->ReadRecordBatch(i, &batch)); - original_chunks.push_back(batch); - ABORT_NOT_OK(write_stream->WriteRecordBatch(*batch)); - } - ABORT_NOT_OK(write_stream->Close()); - - std::shared_ptr original_data; - ABORT_NOT_OK( - arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data)); + std::unique_ptr table_reader( + new arrow::TableBatchReader(*original_data)); + ABORT_NOT_OK(CopyReaderToWriter(table_reader, write_stream)); // 2. Get the ticket for the data. std::unique_ptr info; @@ -97,26 +144,10 @@ int main(int argc, char** argv) { std::cout << "Verifying location " << location.host << ':' << location.port << std::endl; // 3. Download the data from the server. - std::unique_ptr read_client; - ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location.host, location.port, - &read_client)); - - std::unique_ptr stream; - ABORT_NOT_OK(read_client->DoGet(ticket, schema, &stream)); - - std::vector> retrieved_chunks; - std::shared_ptr chunk; - while (true) { - ABORT_NOT_OK(stream->ReadNext(&chunk)); - if (chunk == nullptr) break; - retrieved_chunks.push_back(chunk); - } - - // 4. Validate that the data is equal. std::shared_ptr retrieved_data; - ABORT_NOT_OK( - arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data)); + ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, schema, &retrieved_data)); + // 4. Validate that the data is equal. if (!original_data->Equals(*retrieved_data)) { std::cerr << "Data does not match!" << std::endl; return 1; diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc index c50e52713b33f..7e201a031943d 100644 --- a/cpp/src/arrow/flight/test-integration-server.cc +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -47,7 +47,7 @@ class FlightIntegrationTestServer : public FlightServerBase { auto data = uploaded_chunks.find(request.path[0]); if (data == uploaded_chunks.end()) { - return Status::KeyError("Could not find flight."); + return Status::KeyError("Could not find flight.", request.path[0]); } auto flight = data->second; @@ -72,7 +72,7 @@ class FlightIntegrationTestServer : public FlightServerBase { std::unique_ptr* data_stream) override { auto data = uploaded_chunks.find(request.ticket); if (data == uploaded_chunks.end()) { - return Status::KeyError("Could not find flight."); + return Status::KeyError("Could not find flight.", request.ticket); } auto flight = data->second; From 562b86132a8e8e8464ad464b4cec1ec1c4c4474f Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Feb 2019 15:06:35 -0500 Subject: [PATCH 10/20] Factor out FlightData->Message conversion --- cpp/src/arrow/flight/internal.cc | 16 ++++++++++++++++ cpp/src/arrow/flight/internal.h | 3 +++ cpp/src/arrow/flight/server.cc | 10 +++------- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc index 629796ea36850..b614dd5b3ffc0 100644 --- a/cpp/src/arrow/flight/internal.cc +++ b/cpp/src/arrow/flight/internal.cc @@ -131,6 +131,22 @@ void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket) { pb_ticket->set_ticket(ticket.ticket); } +// FlightData + +Status FromProto(const pb::FlightData& pb_data, + FlightDescriptor* descriptor, + std::unique_ptr* message) { + RETURN_NOT_OK(internal::FromProto(pb_data.flight_descriptor(), descriptor)); + const std::string& header = pb_data.data_header(); + const std::string& body = pb_data.data_body(); + std::shared_ptr header_buf = Buffer::Wrap(header.data(), header.size()); + std::shared_ptr body_buf = Buffer::Wrap(body.data(), body.size()); + if (header_buf == nullptr || body_buf == nullptr) { + return Status::UnknownError("Could not create buffers from protobuf"); + } + return ipc::Message::Open(header_buf, body_buf, message); +} + // FlightEndpoint Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint) { diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h index bae1eedfa9c66..a4bafd2693df9 100644 --- a/cpp/src/arrow/flight/internal.h +++ b/cpp/src/arrow/flight/internal.h @@ -57,6 +57,9 @@ Status FromProto(const pb::Result& pb_result, Result* result); Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria); Status FromProto(const pb::Location& pb_location, Location* location); Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket); +Status FromProto(const pb::FlightData& pb_data, + FlightDescriptor* descriptor, + std::unique_ptr* message); Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr); Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint); Status FromProto(const pb::FlightGetInfo& pb_info, FlightInfo::Data* info); diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 1d5b215c8ead3..bd01e9ab7d9a8 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -213,19 +213,15 @@ class FlightServiceImpl : public FlightService::Service { pb::FlightData data; if (reader->Read(&data)) { FlightDescriptor descriptor; + // Message only lives as long as data std::unique_ptr message; - std::shared_ptr schema; - GRPC_RETURN_NOT_OK(internal::FromProto(data.flight_descriptor(), &descriptor)); - std::shared_ptr header_buf; - std::shared_ptr body_buf; - GRPC_RETURN_NOT_OK(Buffer::FromString(data.data_header(), &header_buf)); - GRPC_RETURN_NOT_OK(Buffer::FromString(data.data_body(), &body_buf)); - GRPC_RETURN_NOT_OK(ipc::Message::Open(header_buf, body_buf, &message)); + GRPC_RETURN_NOT_OK(internal::FromProto(data, &descriptor, &message)); if (!message || message->type() != ipc::Message::Type::SCHEMA) { return internal::ToGrpcStatus( Status(StatusCode::Invalid, "DoPut must start with schema/descriptor")); } else { + std::shared_ptr schema; GRPC_RETURN_NOT_OK(ipc::ReadSchema(*message, &schema)); auto message_reader = std::unique_ptr( From 419ad688b82735a81abd990ca1c2cba4bbc196d1 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Feb 2019 15:42:24 -0500 Subject: [PATCH 11/20] Log (de)serialization failures in Flight fast-path --- cpp/src/arrow/flight/serialization-internal.h | 48 +++++++++++++------ cpp/src/arrow/flight/server.cc | 4 -- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h index e412e247012f1..4fa7238ddebfd 100644 --- a/cpp/src/arrow/flight/serialization-internal.h +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -156,23 +156,42 @@ class GrpcBuffer : public arrow::MutableBuffer { grpc_slice slice_; }; +// Helper to log status code, as gRPC doesn't expose why +// (de)serialization fails +inline Status FailSerialization(Status status) { + if (!status.ok()) { + ARROW_LOG(WARNING) << "Error deserializing Flight message: " + << status.error_message(); + } + return status; +} + +inline arrow::Status FailSerialization(arrow::Status status) { + if (!status.ok()) { + ARROW_LOG(WARNING) << "Error deserializing Flight message: " + << status.ToString(); + } + return status; +} + // Read internal::FlightData from grpc::ByteBuffer containing FlightData // protobuf without copying template <> class SerializationTraits { public: static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { - return Status(StatusCode::UNIMPLEMENTED, - "internal::FlightData serialization not implemented"); + return FailSerialization( + Status(StatusCode::UNIMPLEMENTED, + "internal::FlightData serialization not implemented")); } static Status Deserialize(ByteBuffer* buffer, FlightData* out) { if (!buffer) { - return Status(StatusCode::INTERNAL, "No payload"); + return FailSerialization(Status(StatusCode::INTERNAL, "No payload")); } std::shared_ptr wrapped_buffer; - GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer)); + GRPC_RETURN_NOT_OK(FailSerialization(GrpcBuffer::Wrap(buffer, &wrapped_buffer))); auto buffer_length = static_cast(wrapped_buffer->size()); CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); @@ -188,17 +207,20 @@ class SerializationTraits { case pb::FlightData::kFlightDescriptorFieldNumber: { pb::FlightDescriptor pb_descriptor; if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { - return Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor"); + return FailSerialization(Status(StatusCode::INTERNAL, + "Unable to parse FlightDescriptor")); } } break; case pb::FlightData::kDataHeaderFieldNumber: { if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData metadata"); + return FailSerialization(Status(StatusCode::INTERNAL, + "Unable to read FlightData metadata")); } } break; case pb::FlightData::kDataBodyFieldNumber: { if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData body"); + return FailSerialization(Status(StatusCode::INTERNAL, + "Unable to read FlightData body")); } } break; default: @@ -219,8 +241,8 @@ template <> class SerializationTraits { public: static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) { - return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, - "IpcPayload deserialization not implemented"); + return FailSerialization(grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "IpcPayload deserialization not implemented")); } static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out, @@ -256,8 +278,9 @@ class SerializationTraits { // TODO(wesm): messages over 2GB unlikely to be yet supported if (total_size > kInt32Max) { - return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, - "Cannot send record batches exceeding 2GB yet"); + return FailSerialization(grpc::Status( + grpc::StatusCode::INVALID_ARGUMENT, + "Cannot send record batches exceeding 2GB yet")); } // Allocate slice, assign to output buffer @@ -310,7 +333,4 @@ class SerializationTraits { } }; -template class grpc::ClientWriter; -template class grpc::ClientReader; - } // namespace grpc diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index bd01e9ab7d9a8..b2f24fe525325 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -229,10 +229,6 @@ class FlightServiceImpl : public FlightService::Service { return internal::ToGrpcStatus(server_->DoPut(std::move(message_reader))); } } else { - // TODO(lihalite): gRPC doesn't let us distinguish between no - // message sent, and message failed to deserialize. IMO, we - // should add logging around the Status returns in - // serialization-internal.h to make debugging such cases easier. return grpc::Status::OK; } } From 302dd334d1d730e2325951cd921f3c20ec8a50b9 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 30 Jan 2019 09:42:22 -0500 Subject: [PATCH 12/20] Explicitly link Protobuf for Flight --- cpp/src/arrow/flight/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 1109e0eb9da70..1cbef6cf81808 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -21,6 +21,7 @@ add_custom_target(arrow_flight) ARROW_INSTALL_ALL_HEADERS("arrow/flight") SET(ARROW_FLIGHT_STATIC_LINK_LIBS + protobuf_static grpc_grpcpp_static grpc_grpc_static grpc_gpr_static From cfa4ca5e0223000df53dfb9da761ecaee923ea15 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 31 Jan 2019 14:04:18 -0500 Subject: [PATCH 13/20] Properly quote arguments to gRPC CMake build --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 5ee0ddfd55914..37808e4cb0d8d 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1406,15 +1406,15 @@ if (ARROW_WITH_GRPC) set(GRPC_CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} - -DCMAKE_PREFIX_PATH="${GRPC_PREFIX_PATH_ALT_SEP}" - "-DgRPC_CARES_PROVIDER=package" - "-DgRPC_GFLAGS_PROVIDER=package" - "-DgRPC_PROTOBUF_PROVIDER=package" - "-DgRPC_SSL_PROVIDER=package" - "-DgRPC_ZLIB_PROVIDER=package" - "-DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS}" - "-DCMAKE_C_FLAGS=${EP_C_FLAGS}" - "-DCMAKE_INSTALL_PREFIX=${GRPC_PREFIX}" + -DCMAKE_PREFIX_PATH='${GRPC_PREFIX_PATH_ALT_SEP}' + '-DgRPC_CARES_PROVIDER=package' + '-DgRPC_GFLAGS_PROVIDER=package' + '-DgRPC_PROTOBUF_PROVIDER=package' + '-DgRPC_SSL_PROVIDER=package' + '-DgRPC_ZLIB_PROVIDER=package' + '-DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS}' + '-DCMAKE_C_FLAGS=${EP_C_FLAGS}' + '-DCMAKE_INSTALL_PREFIX=${GRPC_PREFIX}' -DCMAKE_INSTALL_LIBDIR=lib -DBUILD_SHARED_LIBS=OFF) From 58d6936e07dbeb9d18f341a0d1b82a6a379db6f5 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 31 Jan 2019 14:04:05 -0500 Subject: [PATCH 14/20] Enable building with non-CMake c-ares --- cpp/cmake_modules/Findc-ares.cmake | 108 ++++++++++++++++++++ cpp/cmake_modules/ThirdpartyToolchain.cmake | 7 +- 2 files changed, 109 insertions(+), 6 deletions(-) create mode 100644 cpp/cmake_modules/Findc-ares.cmake diff --git a/cpp/cmake_modules/Findc-ares.cmake b/cpp/cmake_modules/Findc-ares.cmake new file mode 100644 index 0000000000000..1366ce33fa790 --- /dev/null +++ b/cpp/cmake_modules/Findc-ares.cmake @@ -0,0 +1,108 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Tries to find c-ares headers and libraries. +# +# Usage of this module as follows: +# +# find_package(c-ares) +# +# Variables used by this module, they can change the default behaviour and need +# to be set before calling find_package: +# +# CARES_HOME - When set, this path is inspected instead of standard library +# locations as the root of the c-ares installation. +# The environment variable CARES_HOME overrides this variable. +# +# - Find CARES +# This module defines +# CARES_INCLUDE_DIR, directory containing headers +# CARES_SHARED_LIB, path to c-ares's shared library +# CARES_FOUND, whether c-ares has been found + +if( NOT "${CARES_HOME}" STREQUAL "") + file( TO_CMAKE_PATH "${CARES_HOME}" _native_path ) + list( APPEND _cares_roots ${_native_path} ) +elseif ( CARES_HOME ) + list( APPEND _cares_roots ${CARES_HOME} ) +endif() + +if (MSVC) + set(CARES_LIB_NAME cares.lib) +else () + set(CARES_LIB_NAME + ${CMAKE_SHARED_LIBRARY_PREFIX}cares${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(CARES_STATIC_LIB_NAME + ${CMAKE_STATIC_LIBRARY_PREFIX}cares${CMAKE_STATIC_LIBRARY_SUFFIX}) +endif () + +# Try the parameterized roots, if they exist +if (_cares_roots) + find_path(CARES_INCLUDE_DIR NAMES ares.h + PATHS ${_cares_roots} NO_DEFAULT_PATH + PATH_SUFFIXES "include") + find_library(CARES_SHARED_LIB + NAMES ${CARES_LIB_NAME} + PATHS ${_cares_roots} NO_DEFAULT_PATH + PATH_SUFFIXES "lib") + find_library(CARES_STATIC_LIB + NAMES ${CARES_STATIC_LIB_NAME} + PATHS ${_cares_roots} NO_DEFAULT_PATH + PATH_SUFFIXES "lib") +else () + pkg_check_modules(PKG_CARES cares) + if (PKG_CARES_FOUND) + set(CARES_INCLUDE_DIR ${PKG_CARES_INCLUDEDIR}) + find_library(CARES_SHARED_LIB + NAMES ${CARES_LIB_NAME} + PATHS ${PKG_CARES_LIBDIR} NO_DEFAULT_PATH) + else () + find_path(CARES_INCLUDE_DIR NAMES cares.h) + find_library(CARES_SHARED_LIB NAMES ${CARES_LIB_NAME}) + endif () +endif () + +if (CARES_INCLUDE_DIR AND CARES_SHARED_LIB) + set(CARES_FOUND TRUE) +else () + set(CARES_FOUND FALSE) +endif () + +if (CARES_FOUND) + if (NOT CARES_FIND_QUIETLY) + if (CARES_SHARED_LIB) + message(STATUS "Found the c-ares shared library: ${CARES_SHARED_LIB}") + endif () + endif () +else () + if (NOT CARES_FIND_QUIETLY) + set(CARES_ERR_MSG "Could not find the c-ares library. Looked in ") + if ( _cares_roots ) + set(CARES_ERR_MSG "${CARES_ERR_MSG} ${_cares_roots}.") + else () + set(CARES_ERR_MSG "${CARES_ERR_MSG} system search paths.") + endif () + if (CARES_FIND_REQUIRED) + message(FATAL_ERROR "${CARES_ERR_MSG}") + else (CARES_FIND_REQUIRED) + message(STATUS "${CARES_ERR_MSG}") + endif (CARES_FIND_REQUIRED) + endif () +endif () + +mark_as_advanced( + CARES_INCLUDE_DIR + CARES_LIBRARIES + CARES_SHARED_LIB + CARES_STATIC_LIB +) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 37808e4cb0d8d..9bd6cb68c70e7 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1347,12 +1347,7 @@ if (ARROW_WITH_GRPC) BUILD_BYPRODUCTS "${CARES_STATIC_LIB}") else() set(CARES_VENDORED 0) - find_package(c-ares REQUIRED - PATHS ${CARES_HOME} - NO_DEFAULT_PATH) - if(TARGET c-ares::cares) - get_property(CARES_STATIC_LIB TARGET c-ares::cares_static PROPERTY LOCATION) - endif() + find_package(c-ares REQUIRED) endif() message(STATUS "c-ares library: ${CARES_STATIC_LIB}") From 6edf2e2bc9e2f88842b1da31244cbab4fb2aff40 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Feb 2019 16:50:39 -0500 Subject: [PATCH 15/20] Introduce FlightPutWriter --- cpp/src/arrow/flight/client.cc | 41 ++++++++++++------- cpp/src/arrow/flight/client.h | 28 +++++++++---- cpp/src/arrow/flight/internal.cc | 3 +- cpp/src/arrow/flight/internal.h | 3 +- cpp/src/arrow/flight/serialization-internal.h | 26 ++++++------ .../arrow/flight/test-integration-client.cc | 10 ++--- 6 files changed, 67 insertions(+), 44 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 64d9c82fa552b..26cf15bf1db94 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -115,16 +115,13 @@ class FlightClient; /// \brief A RecordBatchWriter implementation that writes to a Flight /// DoPut stream. -class FlightStreamWriter : public ipc::RecordBatchWriter { +class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter { public: - explicit FlightStreamWriter(std::unique_ptr rpc, - const FlightDescriptor& descriptor, - const std::shared_ptr& schema, - MemoryPool* pool = default_memory_pool()) - : rpc_(std::move(rpc)), - descriptor_(descriptor), - schema_(schema), - pool_(pool) {} + explicit FlightPutWriterImpl(std::unique_ptr rpc, + const FlightDescriptor& descriptor, + const std::shared_ptr& schema, + MemoryPool* pool = default_memory_pool()) + : rpc_(std::move(rpc)), descriptor_(descriptor), schema_(schema), pool_(pool) {} Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { IpcPayload payload; @@ -171,6 +168,22 @@ class FlightStreamWriter : public ipc::RecordBatchWriter { friend class FlightClient; }; +FlightPutWriter::~FlightPutWriter() {} + +FlightPutWriter::FlightPutWriter(std::unique_ptr impl) { + impl_ = std::move(impl); +} + +Status FlightPutWriter::WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) { + return impl_->WriteRecordBatch(batch, allow_64bit); +} + +Status FlightPutWriter::Close() { return impl_->Close(); } + +void FlightPutWriter::set_memory_pool(MemoryPool* pool) { + return impl_->set_memory_pool(pool); +} + class FlightClient::FlightClientImpl { public: Status Connect(const std::string& host, int port) { @@ -277,10 +290,10 @@ class FlightClient::FlightClientImpl { } Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream) { + std::unique_ptr* stream) { std::unique_ptr rpc(new ClientRpc); - std::unique_ptr out( - new FlightStreamWriter(std::move(rpc), descriptor, schema)); + std::unique_ptr out( + new FlightPutWriter::FlightPutWriterImpl(std::move(rpc), descriptor, schema)); std::unique_ptr> write_stream( stub_->DoPut(&out->rpc_->context, &out->response)); @@ -302,7 +315,7 @@ class FlightClient::FlightClientImpl { } out->set_stream(std::move(write_stream)); - *stream = std::move(out); + *stream = std::unique_ptr(new FlightPutWriter(std::move(out))); return Status::OK(); } @@ -350,7 +363,7 @@ Status FlightClient::DoGet(const Ticket& ticket, const std::shared_ptr& Status FlightClient::DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream) { + std::unique_ptr* stream) { return impl_->DoPut(descriptor, schema, stream); } diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 730d8b10f63bb..0ef96c500cfa0 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -24,6 +24,7 @@ #include #include +#include "arrow/ipc/writer.h" #include "arrow/status.h" #include "arrow/util/visibility.h" @@ -35,14 +36,10 @@ class RecordBatch; class RecordBatchReader; class Schema; -namespace ipc { - -class RecordBatchWriter; - -} // namespace ipc - namespace flight { +class FlightPutWriter; + /// \brief Client class for Arrow Flight RPC services (gRPC-based). /// API experimental for now class ARROW_EXPORT FlightClient { @@ -108,7 +105,7 @@ class ARROW_EXPORT FlightClient { /// \param[out] stream a writer to write record batches to /// \return Status Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream); + std::unique_ptr* stream); private: FlightClient(); @@ -116,5 +113,22 @@ class ARROW_EXPORT FlightClient { std::unique_ptr impl_; }; +/// \brief An interface to upload record batches to a Flight server +class ARROW_EXPORT FlightPutWriter : public ipc::RecordBatchWriter { + public: + ~FlightPutWriter(); + + Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override; + Status Close() override; + void set_memory_pool(MemoryPool* pool) override; + + private: + class FlightPutWriterImpl; + explicit FlightPutWriter(std::unique_ptr impl); + std::unique_ptr impl_; + + friend class FlightClient; +}; + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc index b614dd5b3ffc0..a614450e8d0a0 100644 --- a/cpp/src/arrow/flight/internal.cc +++ b/cpp/src/arrow/flight/internal.cc @@ -133,8 +133,7 @@ void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket) { // FlightData -Status FromProto(const pb::FlightData& pb_data, - FlightDescriptor* descriptor, +Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, std::unique_ptr* message) { RETURN_NOT_OK(internal::FromProto(pb_data.flight_descriptor(), descriptor)); const std::string& header = pb_data.data_header(); diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h index a4bafd2693df9..7f9bda138cbb1 100644 --- a/cpp/src/arrow/flight/internal.h +++ b/cpp/src/arrow/flight/internal.h @@ -57,8 +57,7 @@ Status FromProto(const pb::Result& pb_result, Result* result); Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria); Status FromProto(const pb::Location& pb_location, Location* location); Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket); -Status FromProto(const pb::FlightData& pb_data, - FlightDescriptor* descriptor, +Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, std::unique_ptr* message); Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr); Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint); diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h index 4fa7238ddebfd..73d15f6bfaf07 100644 --- a/cpp/src/arrow/flight/serialization-internal.h +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -168,8 +168,7 @@ inline Status FailSerialization(Status status) { inline arrow::Status FailSerialization(arrow::Status status) { if (!status.ok()) { - ARROW_LOG(WARNING) << "Error deserializing Flight message: " - << status.ToString(); + ARROW_LOG(WARNING) << "Error deserializing Flight message: " << status.ToString(); } return status; } @@ -180,9 +179,8 @@ template <> class SerializationTraits { public: static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { - return FailSerialization( - Status(StatusCode::UNIMPLEMENTED, - "internal::FlightData serialization not implemented")); + return FailSerialization(Status( + StatusCode::UNIMPLEMENTED, "internal::FlightData serialization not implemented")); } static Status Deserialize(ByteBuffer* buffer, FlightData* out) { @@ -207,20 +205,20 @@ class SerializationTraits { case pb::FlightData::kFlightDescriptorFieldNumber: { pb::FlightDescriptor pb_descriptor; if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { - return FailSerialization(Status(StatusCode::INTERNAL, - "Unable to parse FlightDescriptor")); + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor")); } } break; case pb::FlightData::kDataHeaderFieldNumber: { if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { - return FailSerialization(Status(StatusCode::INTERNAL, - "Unable to read FlightData metadata")); + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to read FlightData metadata")); } } break; case pb::FlightData::kDataBodyFieldNumber: { if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { - return FailSerialization(Status(StatusCode::INTERNAL, - "Unable to read FlightData body")); + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to read FlightData body")); } } break; default: @@ -278,9 +276,9 @@ class SerializationTraits { // TODO(wesm): messages over 2GB unlikely to be yet supported if (total_size > kInt32Max) { - return FailSerialization(grpc::Status( - grpc::StatusCode::INVALID_ARGUMENT, - "Cannot send record batches exceeding 2GB yet")); + return FailSerialization( + grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Cannot send record batches exceeding 2GB yet")); } // Allocate slice, assign to output buffer diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 3954bc8523b94..8a6414adf2b19 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -68,14 +68,14 @@ arrow::Status ReadToTable(std::unique_ptr& reader, - std::unique_ptr& writer) { + arrow::ipc::RecordBatchWriter& writer) { while (true) { std::shared_ptr chunk; RETURN_NOT_OK(reader->ReadNext(&chunk)); if (chunk == nullptr) break; - RETURN_NOT_OK(writer->WriteRecordBatch(*chunk)); + RETURN_NOT_OK(writer.WriteRecordBatch(*chunk)); } - return writer->Close(); + return writer.Close(); } /// \brief Helper to read a flight into a Table. @@ -114,11 +114,11 @@ int main(int argc, char** argv) { std::shared_ptr original_data; ABORT_NOT_OK(ReadToTable(reader, &original_data)); - std::unique_ptr write_stream; + std::unique_ptr write_stream; ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); std::unique_ptr table_reader( new arrow::TableBatchReader(*original_data)); - ABORT_NOT_OK(CopyReaderToWriter(table_reader, write_stream)); + ABORT_NOT_OK(CopyReaderToWriter(table_reader, *write_stream)); // 2. Get the ticket for the data. std::unique_ptr info; From 21b315a5a8363b149af4248aae68dc79047351fd Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 5 Feb 2019 09:40:52 -0500 Subject: [PATCH 16/20] Hide FlightPutWriter from public interface for now --- cpp/src/arrow/flight/client.cc | 7 ++++--- cpp/src/arrow/flight/client.h | 2 +- cpp/src/arrow/flight/test-integration-client.cc | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 26cf15bf1db94..77c12eee90a5f 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -290,7 +290,7 @@ class FlightClient::FlightClientImpl { } Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream) { + std::unique_ptr* stream) { std::unique_ptr rpc(new ClientRpc); std::unique_ptr out( new FlightPutWriter::FlightPutWriterImpl(std::move(rpc), descriptor, schema)); @@ -315,7 +315,8 @@ class FlightClient::FlightClientImpl { } out->set_stream(std::move(write_stream)); - *stream = std::unique_ptr(new FlightPutWriter(std::move(out))); + *stream = + std::unique_ptr(new FlightPutWriter(std::move(out))); return Status::OK(); } @@ -363,7 +364,7 @@ Status FlightClient::DoGet(const Ticket& ticket, const std::shared_ptr& Status FlightClient::DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream) { + std::unique_ptr* stream) { return impl_->DoPut(descriptor, schema, stream); } diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 0ef96c500cfa0..ef960417b024a 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -105,7 +105,7 @@ class ARROW_EXPORT FlightClient { /// \param[out] stream a writer to write record batches to /// \return Status Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream); + std::unique_ptr* stream); private: FlightClient(); diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 8a6414adf2b19..62522833f4ba3 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -114,7 +114,7 @@ int main(int argc, char** argv) { std::shared_ptr original_data; ABORT_NOT_OK(ReadToTable(reader, &original_data)); - std::unique_ptr write_stream; + std::unique_ptr write_stream; ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); std::unique_ptr table_reader( new arrow::TableBatchReader(*original_data)); From 1f816e8c3184ad9f3342eeb4be5e85ff875768cd Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 5 Feb 2019 10:18:31 -0500 Subject: [PATCH 17/20] Move serialization helpers out of gRPC namespace --- .../arrow/flight/serialization-internal.cc | 8 ++++-- cpp/src/arrow/flight/serialization-internal.h | 27 +++++++++++++------ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc index f3dbba47255f4..194a7b5bc0c30 100644 --- a/cpp/src/arrow/flight/serialization-internal.cc +++ b/cpp/src/arrow/flight/serialization-internal.cc @@ -17,7 +17,9 @@ #include "arrow/flight/serialization-internal.h" -namespace grpc { +namespace arrow { +namespace flight { +namespace internal { bool ReadBytesZeroCopy(const std::shared_ptr& source_data, CodedInputStream* input, std::shared_ptr* out) { @@ -30,4 +32,6 @@ bool ReadBytesZeroCopy(const std::shared_ptr& source_data, return input->Skip(static_cast(length)); } -} // namespace grpc +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h index 73d15f6bfaf07..d4254d606d40f 100644 --- a/cpp/src/arrow/flight/serialization-internal.h +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -61,17 +61,11 @@ struct FlightData { std::shared_ptr body; }; -} // namespace flight -} // namespace arrow +namespace internal { -namespace grpc { - -using google::protobuf::internal::WireFormatLite; using google::protobuf::io::CodedInputStream; using google::protobuf::io::CodedOutputStream; -using arrow::flight::FlightData; - // More efficient writing of FlightData to gRPC output buffer // Implementation of ZeroCopyOutputStream that writes to a fixed-size buffer class FixedSizeProtoWriter : public ::google::protobuf::io::ZeroCopyOutputStream { @@ -116,7 +110,8 @@ class GrpcBuffer : public arrow::MutableBuffer { grpc_slice_unref(slice_); } - static arrow::Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr* out) { + static arrow::Status Wrap(grpc::ByteBuffer* cpp_buf, + std::shared_ptr* out) { // These types are guaranteed by static assertions in gRPC to have the same // in-memory representation @@ -156,6 +151,22 @@ class GrpcBuffer : public arrow::MutableBuffer { grpc_slice slice_; }; +} // namespace internal + +} // namespace flight +} // namespace arrow + +namespace grpc { + +using arrow::flight::FlightData; +using arrow::flight::internal::FixedSizeProtoWriter; +using arrow::flight::internal::GrpcBuffer; +using arrow::flight::internal::ReadBytesZeroCopy; + +using google::protobuf::internal::WireFormatLite; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; + // Helper to log status code, as gRPC doesn't expose why // (de)serialization fails inline Status FailSerialization(Status status) { From cd567820300912b3d784c06d05b1440a4e4ffb87 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 5 Feb 2019 09:37:04 -0500 Subject: [PATCH 18/20] Warn about undefined behavior in Flight source --- cpp/src/arrow/flight/client.cc | 2 ++ cpp/src/arrow/flight/server.cc | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 77c12eee90a5f..af4af60f6580b 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -75,6 +75,7 @@ class FlightStreamReader : public RecordBatchReader { } // For customizing read path for better memory/serialization efficiency + // XXX this cast is undefined behavior auto custom_reader = reinterpret_cast*>(stream_.get()); // Explicitly specify the override to invoke - otherwise compiler @@ -126,6 +127,7 @@ class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter { Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { IpcPayload payload; RETURN_NOT_OK(ipc::internal::GetRecordBatchPayload(batch, pool_, &payload)); + // XXX this cast is undefined behavior auto custom_writer = reinterpret_cast*>(writer_.get()); // Explicitly specify the override to invoke - otherwise compiler // may invoke through vtable (not updated by reinterpret_cast) diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index b2f24fe525325..0788952c2d340 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -73,6 +73,7 @@ class FlightMessageReaderImpl : public FlightMessageReader { return Status::OK(); } + // XXX this cast is undefined behavior auto custom_reader = reinterpret_cast*>(reader_); FlightData data; @@ -184,6 +185,7 @@ class FlightServiceImpl : public FlightService::Service { GRPC_RETURN_NOT_OK(server_->DoGet(ticket, &data_stream)); // Requires ServerWriter customization in grpc_customizations.h + // XXX this cast is undefined behavior auto custom_writer = reinterpret_cast*>(writer); // Write the schema as the first message in the stream @@ -192,7 +194,10 @@ class FlightServiceImpl : public FlightService::Service { ipc::DictionaryMemo dictionary_memo; GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload( *data_stream->schema(), pool, &dictionary_memo, &schema_payload)); - custom_writer->Write(schema_payload, grpc::WriteOptions()); + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + custom_writer->grpc::ServerWriter::Write(schema_payload, + grpc::WriteOptions()); while (true) { IpcPayload payload; From f32c0b25d562b5e8a9188c844b37fac086d629c0 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 5 Feb 2019 09:44:39 -0500 Subject: [PATCH 19/20] Indicate error to client in DoPut if no message sent --- cpp/src/arrow/flight/client.cc | 11 ++++++++--- cpp/src/arrow/flight/server.cc | 4 +++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index af4af60f6580b..a58c2b5933225 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -133,8 +133,10 @@ class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter { // may invoke through vtable (not updated by reinterpret_cast) if (!custom_writer->grpc::ClientWriter::Write(payload, grpc::WriteOptions())) { - // Stream ended? - return Status::UnknownError("Could not write record batch to stream"); + std::stringstream ss; + ss << "Could not write record batch to stream: " + << rpc_->context.debug_error_string(); + return Status::IOError(ss.str()); } return Status::OK(); } @@ -313,7 +315,10 @@ class FlightClient::FlightClientImpl { descriptor_message.set_data_header(header_buf->ToString()); if (!write_stream->Write(descriptor_message, grpc::WriteOptions())) { - return Status::UnknownError("Could not write initial message to stream"); + std::stringstream ss; + ss << "Could not write descriptor and schema to stream: " + << rpc->context.debug_error_string(); + return Status::IOError(ss.str()); } out->set_stream(std::move(write_stream)); diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 0788952c2d340..ac5b53532866f 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -234,7 +234,9 @@ class FlightServiceImpl : public FlightService::Service { return internal::ToGrpcStatus(server_->DoPut(std::move(message_reader))); } } else { - return grpc::Status::OK; + return internal::ToGrpcStatus( + Status(StatusCode::Invalid, + "Client provided malformed message or did not provide message")); } } From 13fb29afb3b33b71ca633a7643dee51d6dc0d128 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 5 Feb 2019 09:53:42 -0500 Subject: [PATCH 20/20] Document why VectorUnloader must align batches in Flight --- .../src/main/java/org/apache/arrow/flight/FlightClient.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index 7fcfd52008f98..bd126b5ea203c 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -127,7 +127,9 @@ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRo // send the schema to start. ArrowMessage message = new ArrowMessage(descriptor.toProtocol(), root.getSchema()); observer.onNext(message); - return new PutObserver(new VectorUnloader(root, true, true), observer, resultObserver.getFuture()); + return new PutObserver(new VectorUnloader( + root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */), + observer, resultObserver.getFuture()); } public FlightInfo getInfo(FlightDescriptor descriptor) {