From 905ef38fa3c90c520642d0ae14bf3e6d7b3aa608 Mon Sep 17 00:00:00 2001 From: David Li Date: Sat, 26 Jan 2019 13:45:15 -0500 Subject: [PATCH] Implement C++ Flight DoPut --- cpp/src/arrow/flight/CMakeLists.txt | 1 + cpp/src/arrow/flight/client.cc | 262 ++++++--------- cpp/src/arrow/flight/client.h | 16 +- .../arrow/flight/serialization-internal.cc | 33 ++ cpp/src/arrow/flight/serialization-internal.h | 316 ++++++++++++++++++ cpp/src/arrow/flight/server.cc | 213 +++++------- cpp/src/arrow/flight/server.h | 13 +- .../arrow/flight/test-integration-client.cc | 54 ++- .../arrow/flight/test-integration-server.cc | 102 +++--- cpp/src/arrow/flight/types.h | 6 - integration/integration_test.py | 31 +- .../arrow/flight/example/InMemoryStore.java | 1 + .../integration/IntegrationTestClient.java | 64 ++-- .../integration/IntegrationTestServer.java | 155 ++------- 14 files changed, 720 insertions(+), 547 deletions(-) create mode 100644 cpp/src/arrow/flight/serialization-internal.cc create mode 100644 cpp/src/arrow/flight/serialization-internal.h diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index b8b4d8d336365..1109e0eb9da70 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -69,6 +69,7 @@ set(ARROW_FLIGHT_SRCS Flight.pb.cc Flight.grpc.pb.cc internal.cc + serialization-internal.cc server.cc types.cc ) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index e25c1875d669f..99f88d08a843e 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -22,12 +22,12 @@ #include #include -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/wire_format_lite.h" -#include "grpc/byte_buffer_reader.h" #include "grpcpp/grpcpp.h" +#include "arrow/ipc/dictionary.h" +#include "arrow/ipc/metadata-internal.h" #include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/type.h" @@ -36,161 +36,11 @@ #include "arrow/flight/Flight.grpc.pb.h" #include "arrow/flight/Flight.pb.h" #include "arrow/flight/internal.h" +#include "arrow/flight/serialization-internal.h" -namespace pb = arrow::flight::protocol; - -namespace arrow { -namespace flight { - -/// Internal, not user-visible type used for memory-efficient reads from gRPC -/// stream -struct FlightData { - /// Used only for puts, may be null - std::unique_ptr descriptor; - - /// Non-length-prefixed Message header as described in format/Message.fbs - std::shared_ptr metadata; - - /// Message body - std::shared_ptr body; -}; - -} // namespace flight -} // namespace arrow - -namespace grpc { - -// Customizations to gRPC for more efficient deserialization of FlightData - -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::CodedInputStream; - -using arrow::flight::FlightData; - -bool ReadBytesZeroCopy(const std::shared_ptr& source_data, - CodedInputStream* input, std::shared_ptr* out) { - uint32_t length; - if (!input->ReadVarint32(&length)) { - return false; - } - *out = arrow::SliceBuffer(source_data, input->CurrentPosition(), - static_cast(length)); - return input->Skip(static_cast(length)); -} +using arrow::ipc::internal::IpcPayload; -// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow -// consumers with zero-copy -class GrpcBuffer : public arrow::MutableBuffer { - public: - GrpcBuffer(grpc_slice slice, bool incref) - : MutableBuffer(GRPC_SLICE_START_PTR(slice), - static_cast(GRPC_SLICE_LENGTH(slice))), 
- slice_(incref ? grpc_slice_ref(slice) : slice) {} - - ~GrpcBuffer() override { - // Decref slice - grpc_slice_unref(slice_); - } - - static arrow::Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr* out) { - // These types are guaranteed by static assertions in gRPC to have the same - // in-memory representation - - auto buffer = *reinterpret_cast(cpp_buf); - - // This part below is based on the Flatbuffers gRPC SerializationTraits in - // flatbuffers/grpc.h - - // Check if this is a single uncompressed slice. - if ((buffer->type == GRPC_BB_RAW) && - (buffer->data.raw.compression == GRPC_COMPRESS_NONE) && - (buffer->data.raw.slice_buffer.count == 1)) { - // If it is, then we can reference the `grpc_slice` directly. - grpc_slice slice = buffer->data.raw.slice_buffer.slices[0]; - - // Increment reference count so this memory remains valid - *out = std::make_shared(slice, true); - } else { - // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read - // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives - // us back a new slice with the refcount already incremented. - grpc_byte_buffer_reader reader; - if (!grpc_byte_buffer_reader_init(&reader, buffer)) { - return arrow::Status::IOError("Internal gRPC error reading from ByteBuffer"); - } - grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); - grpc_byte_buffer_reader_destroy(&reader); - - // Steal the slice reference - *out = std::make_shared(slice, false); - } - - return arrow::Status::OK(); - } - - private: - grpc_slice slice_; -}; - -// Read internal::FlightData from grpc::ByteBuffer containing FlightData -// protobuf without copying -template <> -class SerializationTraits { - public: - static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { - return Status(StatusCode::UNIMPLEMENTED, - "internal::FlightData serialization not implemented"); - } - - static Status Deserialize(ByteBuffer* buffer, FlightData* out) { - if (!buffer) { - return Status(StatusCode::INTERNAL, "No payload"); - } - - std::shared_ptr wrapped_buffer; - GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer)); - - auto buffer_length = static_cast(wrapped_buffer->size()); - CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); - - // TODO(wesm): The 2-parameter version of this function is deprecated - pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */); - - // This is the bytes remaining when using CodedInputStream like this - while (pb_stream.BytesUntilTotalBytesLimit()) { - const uint32_t tag = pb_stream.ReadTag(); - const int field_number = WireFormatLite::GetTagFieldNumber(tag); - switch (field_number) { - case pb::FlightData::kFlightDescriptorFieldNumber: { - pb::FlightDescriptor pb_descriptor; - if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { - return Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor"); - } - } break; - case pb::FlightData::kDataHeaderFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData metadata"); - } - } break; - case pb::FlightData::kDataBodyFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData body"); - } - } break; - default: - DCHECK(false) << "cannot happen"; - } - } - buffer->Clear(); - - // TODO(wesm): Where and when should we verify that the FlightData is not - // malformed or missing components? 
- - return Status::OK; - } -}; - -} // namespace grpc +namespace pb = arrow::flight::protocol; namespace arrow { namespace flight { @@ -227,7 +77,9 @@ class FlightStreamReader : public RecordBatchReader { // For customizing read path for better memory/serialization efficiency auto custom_reader = reinterpret_cast*>(stream_.get()); - if (custom_reader->Read(&data)) { + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (custom_reader->grpc::ClientReader::Read(&data)) { std::unique_ptr message; // Validate IPC message @@ -259,6 +111,65 @@ class FlightStreamReader : public RecordBatchReader { std::unique_ptr> stream_; }; +class FlightClient; + +/// \brief A RecordBatchWriter implementation that writes to a Flight +/// DoPut stream. +class FlightStreamWriter : public ipc::RecordBatchWriter { + public: + explicit FlightStreamWriter(std::unique_ptr&& rpc, + const FlightDescriptor& descriptor, + const std::shared_ptr& schema) + : rpc_{std::move(rpc)}, + descriptor_{descriptor}, + schema_{schema}, + pool_{default_memory_pool()} {}; + + Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { + IpcPayload payload; + RETURN_NOT_OK(ipc::internal::GetRecordBatchPayload(batch, pool_, &payload)); + auto custom_writer = reinterpret_cast*>(writer_.get()); + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (!custom_writer->grpc::ClientWriter::Write(payload, + grpc::WriteOptions())) { + // Stream ended? + return Status::UnknownError("Could not write record batch to stream"); + } + return Status::OK(); + } + + Status Close() override { + bool finished_writes = writer_->WritesDone(); + RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish())); + if (!finished_writes) { + return Status::UnknownError( + "Could not finish writing record batches before closing"); + } + return Status::OK(); + } + + void set_memory_pool(MemoryPool* pool) override { pool_ = pool; } + + private: + /// \brief Set the gRPC writer backing this Flight stream. + /// \param [in] writer the gRPC writer + void set_stream(std::unique_ptr>&& writer) { + writer_ = std::move(writer); + } + + // TODO: there isn't a way to access this as a user. + protocol::PutResult response; + std::unique_ptr rpc_; + FlightDescriptor descriptor_; + std::shared_ptr schema_; + std::unique_ptr> writer_; + MemoryPool* pool_; + + // We need to reference some fields + friend class FlightClient; +}; + class FlightClient::FlightClientImpl { public: Status Connect(const std::string& host, int port) { @@ -364,8 +275,34 @@ class FlightClient::FlightClientImpl { return Status::OK(); } - Status DoPut(std::unique_ptr* stream) { - return Status::NotImplemented("DoPut"); + Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, + std::unique_ptr* stream) { + std::unique_ptr rpc(new ClientRpc); + std::unique_ptr out( + new FlightStreamWriter(std::move(rpc), descriptor, schema)); + std::unique_ptr> write_stream( + stub_->DoPut(&out->rpc_->context, &out->response)); + + // First write the descriptor and schema to the stream. 
+ pb::FlightData descriptor_message; + RETURN_NOT_OK( + internal::ToProto(descriptor, descriptor_message.mutable_flight_descriptor())); + + std::shared_ptr header_buf; + RETURN_NOT_OK(Buffer::FromString("", &header_buf)); + ipc::DictionaryMemo dictionary_memo; + RETURN_NOT_OK(ipc::SerializeSchema(*schema, out->pool_, &header_buf)); + RETURN_NOT_OK( + ipc::internal::WriteSchemaMessage(*schema, &dictionary_memo, &header_buf)); + descriptor_message.set_data_header(header_buf->ToString()); + + if (!write_stream->Write(descriptor_message, grpc::WriteOptions())) { + return Status::UnknownError("Could not write initial message to stream"); + } + + out->set_stream(std::move(write_stream)); + *stream = std::move(out); + return Status::OK(); } private: @@ -410,9 +347,10 @@ Status FlightClient::DoGet(const Ticket& ticket, const std::shared_ptr& return impl_->DoGet(ticket, schema, stream); } -Status FlightClient::DoPut(const Schema& schema, - std::unique_ptr* stream) { - return Status::NotImplemented("DoPut"); +Status FlightClient::DoPut(const FlightDescriptor& descriptor, + const std::shared_ptr& schema, + std::unique_ptr* stream) { + return impl_->DoPut(descriptor, schema, stream); } } // namespace flight diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 53bb1755b2995..e548f7c76e848 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -35,6 +35,12 @@ class RecordBatch; class RecordBatchReader; class Schema; +namespace ipc { + +class RecordBatchWriter; + +} // namespace ipc + namespace flight { /// \brief Client class for Arrow Flight RPC services (gRPC-based). @@ -94,12 +100,12 @@ class ARROW_EXPORT FlightClient { Status DoGet(const Ticket& ticket, const std::shared_ptr& schema, std::unique_ptr* stream); - /// \brief Initiate DoPut RPC, returns FlightPutWriter interface to - /// write. Not yet implemented - /// \param[in] schema the schema of the stream data - /// \param[out] stream the created stream to write record batches to + /// \brief Upload data to a Flight described by the given descriptor. + /// \param[in] descriptor the descriptor of the stream + /// \param[out] stream a writer to write record batches to /// \return Status - Status DoPut(const Schema& schema, std::unique_ptr* stream); + Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, + std::unique_ptr* stream); private: FlightClient(); diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc new file mode 100644 index 0000000000000..f3dbba47255f4 --- /dev/null +++ b/cpp/src/arrow/flight/serialization-internal.cc @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
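The serialization-internal.cc and serialization-internal.h files added here hoist the custom gRPC (de)serialization out of client.cc and server.cc and hand-roll the FlightData wire format so Arrow buffers can be sliced out of, and written into, gRPC slices without copies. For contrast, here is a minimal sketch of the copying path this machinery avoids, assuming metadata and body are arrow::Buffer instances produced by the IPC writer; MakeCopyingFlightData is a hypothetical helper, not part of the patch:

    // Hypothetical illustration only: the straightforward (copying) way to ship one
    // IPC message as a FlightData protobuf. Each ToString() copies the Arrow buffer
    // into the message, and the protobuf/gRPC serialization path copies again,
    // which is the overhead the SerializationTraits specializations below avoid.
    #include <memory>

    #include "arrow/buffer.h"
    #include "arrow/flight/Flight.pb.h"

    namespace pb = arrow::flight::protocol;

    pb::FlightData MakeCopyingFlightData(const std::shared_ptr<arrow::Buffer>& metadata,
                                         const std::shared_ptr<arrow::Buffer>& body) {
      pb::FlightData data;
      data.set_data_header(metadata->ToString());  // copy 1: buffer -> std::string
      data.set_data_body(body->ToString());        // copy 2: buffer -> std::string
      return data;  // serializing the message copies the payload once more
    }

As a worked check of the size arithmetic in the IpcPayload Serialize specialization below: with a 200 byte metadata flatbuffer and body buffers of 13 and 24 bytes, the 13 byte buffer is padded to 16, so the body is 40 bytes, and using the 1 byte metadata tag and 2 byte body tag noted in the code, total_size = (1 + 2 + 200) + (2 + 1 + 40) = 246 bytes, where the 2 and 1 are the varint length prefixes.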
+ +#include "arrow/flight/serialization-internal.h" + +namespace grpc { + +bool ReadBytesZeroCopy(const std::shared_ptr& source_data, + CodedInputStream* input, std::shared_ptr* out) { + uint32_t length; + if (!input->ReadVarint32(&length)) { + return false; + } + *out = arrow::SliceBuffer(source_data, input->CurrentPosition(), + static_cast(length)); + return input->Skip(static_cast(length)); +} + +} // namespace grpc diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h new file mode 100644 index 0000000000000..e412e247012f1 --- /dev/null +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -0,0 +1,316 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// (De)serialization utilities that hook into gRPC, efficiently +// handling Arrow-encoded data in a gRPC call. + +#pragma once + +#include +#include + +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/wire_format_lite.h" +#include "grpc/byte_buffer_reader.h" +#include "grpcpp/grpcpp.h" + +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/status.h" +#include "arrow/util/logging.h" + +#include "arrow/flight/Flight.grpc.pb.h" +#include "arrow/flight/Flight.pb.h" +#include "arrow/flight/internal.h" +#include "arrow/flight/types.h" + +namespace pb = arrow::flight::protocol; + +using arrow::ipc::internal::IpcPayload; + +constexpr int64_t kInt32Max = std::numeric_limits::max(); + +namespace arrow { +namespace flight { + +/// Internal, not user-visible type used for memory-efficient reads from gRPC +/// stream +struct FlightData { + /// Used only for puts, may be null + std::unique_ptr descriptor; + + /// Non-length-prefixed Message header as described in format/Message.fbs + std::shared_ptr metadata; + + /// Message body + std::shared_ptr body; +}; + +} // namespace flight +} // namespace arrow + +namespace grpc { + +using google::protobuf::internal::WireFormatLite; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; + +using arrow::flight::FlightData; + +// More efficient writing of FlightData to gRPC output buffer +// Implementation of ZeroCopyOutputStream that writes to a fixed-size buffer +class FixedSizeProtoWriter : public ::google::protobuf::io::ZeroCopyOutputStream { + public: + explicit FixedSizeProtoWriter(grpc_slice slice) + : slice_(slice), + bytes_written_(0), + total_size_(static_cast(GRPC_SLICE_LENGTH(slice))) {} + + bool Next(void** data, int* size) override { + // Consume the whole slice + *data = GRPC_SLICE_START_PTR(slice_) + bytes_written_; + *size = total_size_ - bytes_written_; + bytes_written_ = total_size_; + return true; + } + + void BackUp(int count) override { 
bytes_written_ -= count; } + + int64_t ByteCount() const override { return bytes_written_; } + + private: + grpc_slice slice_; + int bytes_written_; + int total_size_; +}; + +bool ReadBytesZeroCopy(const std::shared_ptr& source_data, + CodedInputStream* input, std::shared_ptr* out); + +// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow +// consumers with zero-copy +class GrpcBuffer : public arrow::MutableBuffer { + public: + GrpcBuffer(grpc_slice slice, bool incref) + : MutableBuffer(GRPC_SLICE_START_PTR(slice), + static_cast(GRPC_SLICE_LENGTH(slice))), + slice_(incref ? grpc_slice_ref(slice) : slice) {} + + ~GrpcBuffer() override { + // Decref slice + grpc_slice_unref(slice_); + } + + static arrow::Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr* out) { + // These types are guaranteed by static assertions in gRPC to have the same + // in-memory representation + + auto buffer = *reinterpret_cast(cpp_buf); + + // This part below is based on the Flatbuffers gRPC SerializationTraits in + // flatbuffers/grpc.h + + // Check if this is a single uncompressed slice. + if ((buffer->type == GRPC_BB_RAW) && + (buffer->data.raw.compression == GRPC_COMPRESS_NONE) && + (buffer->data.raw.slice_buffer.count == 1)) { + // If it is, then we can reference the `grpc_slice` directly. + grpc_slice slice = buffer->data.raw.slice_buffer.slices[0]; + + // Increment reference count so this memory remains valid + *out = std::make_shared(slice, true); + } else { + // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read + // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives + // us back a new slice with the refcount already incremented. + grpc_byte_buffer_reader reader; + if (!grpc_byte_buffer_reader_init(&reader, buffer)) { + return arrow::Status::IOError("Internal gRPC error reading from ByteBuffer"); + } + grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); + grpc_byte_buffer_reader_destroy(&reader); + + // Steal the slice reference + *out = std::make_shared(slice, false); + } + + return arrow::Status::OK(); + } + + private: + grpc_slice slice_; +}; + +// Read internal::FlightData from grpc::ByteBuffer containing FlightData +// protobuf without copying +template <> +class SerializationTraits { + public: + static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { + return Status(StatusCode::UNIMPLEMENTED, + "internal::FlightData serialization not implemented"); + } + + static Status Deserialize(ByteBuffer* buffer, FlightData* out) { + if (!buffer) { + return Status(StatusCode::INTERNAL, "No payload"); + } + + std::shared_ptr wrapped_buffer; + GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer)); + + auto buffer_length = static_cast(wrapped_buffer->size()); + CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); + + // TODO(wesm): The 2-parameter version of this function is deprecated + pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */); + + // This is the bytes remaining when using CodedInputStream like this + while (pb_stream.BytesUntilTotalBytesLimit()) { + const uint32_t tag = pb_stream.ReadTag(); + const int field_number = WireFormatLite::GetTagFieldNumber(tag); + switch (field_number) { + case pb::FlightData::kFlightDescriptorFieldNumber: { + pb::FlightDescriptor pb_descriptor; + if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { + return Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor"); + } + } break; + case pb::FlightData::kDataHeaderFieldNumber: { + 
if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { + return Status(StatusCode::INTERNAL, "Unable to read FlightData metadata"); + } + } break; + case pb::FlightData::kDataBodyFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { + return Status(StatusCode::INTERNAL, "Unable to read FlightData body"); + } + } break; + default: + DCHECK(false) << "cannot happen"; + } + } + buffer->Clear(); + + // TODO(wesm): Where and when should we verify that the FlightData is not + // malformed or missing components? + + return Status::OK; + } +}; + +// Write FlightData to a grpc::ByteBuffer without extra copying +template <> +class SerializationTraits { + public: + static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "IpcPayload deserialization not implemented"); + } + + static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out, + bool* own_buffer) { + size_t total_size = 0; + + DCHECK_LT(msg.metadata->size(), kInt32Max); + const int32_t metadata_size = static_cast(msg.metadata->size()); + + // 1 byte for metadata tag + total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); + + int64_t body_size = 0; + for (const auto& buffer : msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. + if (!buffer) continue; + + body_size += buffer->size(); + + const int64_t remainder = buffer->size() % 8; + if (remainder) { + body_size += 8 - remainder; + } + } + + // 2 bytes for body tag + // Only written when there are body buffers + if (msg.body_length > 0) { + total_size += + 2 + WireFormatLite::LengthDelimitedSize(static_cast(body_size)); + } + + // TODO(wesm): messages over 2GB unlikely to be yet supported + if (total_size > kInt32Max) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Cannot send record batches exceeding 2GB yet"); + } + + // Allocate slice, assign to output buffer + grpc::Slice slice(total_size); + + // XXX(wesm): for debugging + // std::cout << "Writing record batch with total size " << total_size << std::endl; + + FixedSizeProtoWriter writer(*reinterpret_cast(&slice)); + CodedOutputStream pb_stream(&writer); + + // Write header + WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); + pb_stream.WriteVarint32(metadata_size); + pb_stream.WriteRawMaybeAliased(msg.metadata->data(), + static_cast(msg.metadata->size())); + + // Don't write tag if there are no body buffers + if (msg.body_length > 0) { + // Write body + WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); + pb_stream.WriteVarint32(static_cast(body_size)); + + constexpr uint8_t kPaddingBytes[8] = {0}; + + for (const auto& buffer : msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. 
+ if (!buffer) continue; + + pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast(buffer->size())); + + // Write padding if not multiple of 8 + const int remainder = static_cast(buffer->size() % 8); + if (remainder) { + pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); + } + } + } + + DCHECK_EQ(static_cast(total_size), pb_stream.ByteCount()); + + // Hand off the slice to the returned ByteBuffer + grpc::ByteBuffer tmp(&slice, 1); + out->Swap(&tmp); + *own_buffer = true; + return grpc::Status::OK; + } +}; + +template class grpc::ClientWriter; +template class grpc::ClientReader; + +} // namespace grpc diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 018c079501f2f..77f665356f1db 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -18,15 +18,12 @@ #include "arrow/flight/server.h" #include -#include #include #include -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream.h" -#include "google/protobuf/wire_format_lite.h" #include "grpcpp/grpcpp.h" +#include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -35,6 +32,7 @@ #include "arrow/flight/Flight.grpc.pb.h" #include "arrow/flight/Flight.pb.h" #include "arrow/flight/internal.h" +#include "arrow/flight/serialization-internal.h" #include "arrow/flight/types.h" using FlightService = arrow::flight::protocol::FlightService; @@ -47,145 +45,63 @@ using ServerWriter = grpc::ServerWriter; namespace pb = arrow::flight::protocol; -constexpr int64_t kInt32Max = std::numeric_limits::max(); - -namespace grpc { - -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::CodedOutputStream; +namespace arrow { +namespace flight { -// More efficient writing of FlightData to gRPC output buffer -// Implementation of ZeroCopyOutputStream that writes to a fixed-size buffer -class FixedSizeProtoWriter : public ::google::protobuf::io::ZeroCopyOutputStream { - public: - explicit FixedSizeProtoWriter(grpc_slice slice) - : slice_(slice), - bytes_written_(0), - total_size_(static_cast(GRPC_SLICE_LENGTH(slice))) {} - - bool Next(void** data, int* size) override { - // Consume the whole slice - *data = GRPC_SLICE_START_PTR(slice_) + bytes_written_; - *size = total_size_ - bytes_written_; - bytes_written_ = total_size_; - return true; +#define CHECK_ARG_NOT_NULL(VAL, MESSAGE) \ + if (VAL == nullptr) { \ + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \ } - void BackUp(int count) override { bytes_written_ -= count; } - - int64_t ByteCount() const override { return bytes_written_; } - - private: - grpc_slice slice_; - int bytes_written_; - int total_size_; -}; - -// Write FlightData to a grpc::ByteBuffer without extra copying -template <> -class SerializationTraits { +class ARROW_EXPORT FlightMessageReaderImpl : public FlightMessageReader { public: - static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) { - return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, - "IpcPayload deserialization not implemented"); - } - - static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out, - bool* own_buffer) { - size_t total_size = 0; - - DCHECK_LT(msg.metadata->size(), kInt32Max); - const int32_t metadata_size = static_cast(msg.metadata->size()); - - // 1 byte for metadata tag - total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); - - int64_t body_size = 0; - for (const auto& buffer : msg.body_buffers) { - // Buffer may be null 
when the row length is zero, or when all - // entries are invalid. - if (!buffer) continue; - - body_size += buffer->size(); - - const int64_t remainder = buffer->size() % 8; - if (remainder) { - body_size += 8 - remainder; - } - } - - // 2 bytes for body tag - // Only written when there are body buffers - if (msg.body_length > 0) { - total_size += - 2 + WireFormatLite::LengthDelimitedSize(static_cast(body_size)); - } - - // TODO(wesm): messages over 2GB unlikely to be yet supported - if (total_size > kInt32Max) { - return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, - "Cannot send record batches exceeding 2GB yet"); + FlightMessageReaderImpl(const FlightDescriptor& descriptor, + std::shared_ptr schema, + grpc::ServerReader* reader) + : descriptor_{descriptor}, + schema_{schema}, + reader_{reader}, + stream_finished_{false} {} + + const FlightDescriptor& descriptor() const override { return descriptor_; } + + std::shared_ptr schema() const override { return schema_; } + + Status ReadNext(std::shared_ptr* out) override { + if (stream_finished_) { + *out = nullptr; + return Status::OK(); } - // Allocate slice, assign to output buffer - grpc::Slice slice(total_size); + auto custom_reader = reinterpret_cast*>(reader_); - // XXX(wesm): for debugging - // std::cout << "Writing record batch with total size " << total_size << std::endl; + FlightData data; + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (custom_reader->grpc::ServerReader::Read(&data)) { + std::unique_ptr message; - FixedSizeProtoWriter writer(*reinterpret_cast(&slice)); - CodedOutputStream pb_stream(&writer); - - // Write header - WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(metadata_size); - pb_stream.WriteRawMaybeAliased(msg.metadata->data(), - static_cast(msg.metadata->size())); - - // Don't write tag if there are no body buffers - if (msg.body_length > 0) { - // Write body - WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(static_cast(body_size)); - - constexpr uint8_t kPaddingBytes[8] = {0}; - - for (const auto& buffer : msg.body_buffers) { - // Buffer may be null when the row length is zero, or when all - // entries are invalid. 
- if (!buffer) continue; - - pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast(buffer->size())); - - // Write padding if not multiple of 8 - const int remainder = static_cast(buffer->size() % 8); - if (remainder) { - pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); - } + // Validate IPC message + RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message)); + if (message->type() == ipc::Message::Type::RECORD_BATCH) { + return ipc::ReadRecordBatch(*message, schema_, out); + } else { + return Status(StatusCode::Invalid, "Unrecognized message in Flight stream"); } + } else { + // Stream is completed + stream_finished_ = true; + *out = nullptr; + return Status::OK(); } - - DCHECK_EQ(static_cast(total_size), pb_stream.ByteCount()); - - // Hand off the slice to the returned ByteBuffer - grpc::ByteBuffer tmp(&slice, 1); - out->Swap(&tmp); - *own_buffer = true; - return grpc::Status::OK; } -}; -} // namespace grpc - -namespace arrow { -namespace flight { - -#define CHECK_ARG_NOT_NULL(VAL, MESSAGE) \ - if (VAL == nullptr) { \ - return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \ - } + private: + FlightDescriptor descriptor_; + std::shared_ptr schema_; + grpc::ServerReader* reader_; + bool stream_finished_; +}; // This class glues an implementation of FlightServerBase together with the // gRPC service definition, so the latter is not exposed in the public API @@ -293,7 +209,36 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status DoPut(ServerContext* context, grpc::ServerReader* reader, pb::PutResult* response) { - return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, ""); + // Get metadata + pb::FlightData data; + if (reader->Read(&data)) { + FlightDescriptor descriptor; + std::unique_ptr message; + std::shared_ptr schema; + GRPC_RETURN_NOT_OK(internal::FromProto(data.flight_descriptor(), &descriptor)); + std::shared_ptr header_buf; + std::shared_ptr body_buf; + GRPC_RETURN_NOT_OK(Buffer::FromString(data.data_header(), &header_buf)); + GRPC_RETURN_NOT_OK(Buffer::FromString(data.data_body(), &body_buf)); + GRPC_RETURN_NOT_OK(ipc::Message::Open(header_buf, body_buf, &message)); + + if (!message || message->type() != ipc::Message::Type::SCHEMA) { + return internal::ToGrpcStatus( + Status(StatusCode::Invalid, "DoPut must start with schema/descriptor")); + } else { + GRPC_RETURN_NOT_OK(ipc::ReadSchema(*message, &schema)); + + auto message_reader = std::unique_ptr( + new FlightMessageReaderImpl(descriptor, schema, reader)); + return internal::ToGrpcStatus(server_->DoPut(std::move(message_reader))); + } + } else { + // TODO(lihalite): gRPC doesn't let us distinguish between no + // message sent, and message failed to deserialize. IMO, we + // should add logging around the Status returns in + // serialization-internal.h to make debugging such cases easier. 
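The gRPC-level handler here wraps the stream in FlightMessageReaderImpl and delegates to the new FlightServerBase::DoPut virtual declared in server.h. For orientation, a minimal sketch of a user-level override of that virtual, similar in spirit to the integration test server further down; SimpleFlightServer and its uploads_ map are hypothetical names, not part of this patch:

    // Minimal sketch: a server that accepts uploads by draining the
    // FlightMessageReader (a RecordBatchReader that also exposes the descriptor
    // sent in the client's first FlightData message) and keeping the batches in
    // memory, keyed by the descriptor path.
    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    #include "arrow/flight/server.h"
    #include "arrow/record_batch.h"
    #include "arrow/status.h"

    class SimpleFlightServer : public arrow::flight::FlightServerBase {
     public:
      arrow::Status DoPut(
          std::unique_ptr<arrow::flight::FlightMessageReader> reader) override {
        const arrow::flight::FlightDescriptor& descriptor = reader->descriptor();
        if (descriptor.type != arrow::flight::FlightDescriptor::PATH ||
            descriptor.path.empty()) {
          return arrow::Status::Invalid("Expected a path-based descriptor");
        }
        std::vector<std::shared_ptr<arrow::RecordBatch>> batches;
        std::shared_ptr<arrow::RecordBatch> batch;
        while (true) {
          RETURN_NOT_OK(reader->ReadNext(&batch));
          if (batch == nullptr) break;  // client finished the stream
          batches.push_back(batch);
        }
        uploads_[descriptor.path[0]] = std::move(batches);
        return arrow::Status::OK();
      }

     private:
      std::unordered_map<std::string, std::vector<std::shared_ptr<arrow::RecordBatch>>>
          uploads_;
    };

The gRPC handler's final return follows.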
+ return grpc::Status::OK; + } } grpc::Status ListActions(ServerContext* context, const pb::Empty* request, @@ -376,6 +321,10 @@ Status FlightServerBase::DoGet(const Ticket& request, return Status::NotImplemented("NYI"); } +Status FlightServerBase::DoPut(std::unique_ptr reader) { + return Status::NotImplemented("NYI"); +} + Status FlightServerBase::DoAction(const Action& action, std::unique_ptr* result) { return Status::NotImplemented("NYI"); diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index b3b8239132b7a..f975b8619cd48 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -29,6 +29,7 @@ #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" +#include "arrow/record_batch.h" namespace arrow { @@ -81,6 +82,13 @@ class ARROW_EXPORT RecordBatchStream : public FlightDataStream { std::shared_ptr reader_; }; +/// \brief A reader for IPC payloads uploaded by a client +class ARROW_EXPORT FlightMessageReader : public RecordBatchReader { + public: + /// \brief Get the descriptor for this upload. + virtual const FlightDescriptor& descriptor() const = 0; +}; + /// \brief Skeleton RPC server implementation which can be used to create /// custom servers by implementing its abstract methods class ARROW_EXPORT FlightServerBase { @@ -125,7 +133,10 @@ class ARROW_EXPORT FlightServerBase { /// \return Status virtual Status DoGet(const Ticket& request, std::unique_ptr* stream); - // virtual Status DoPut(std::unique_ptr* reader) = 0; + /// \brief Process a stream of IPC payloads sent from a client + /// \param[in] reader a sequence of uploaded record batches + /// \return Status + virtual Status DoPut(std::unique_ptr reader); /// \brief Execute an action, return stream of zero or more results /// \param[in] action the action to execute, with type and body diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 267025a451cc7..a94001ff33713 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -31,6 +31,7 @@ #include "arrow/io/test-common.h" #include "arrow/ipc/json.h" #include "arrow/record_batch.h" +#include "arrow/table.h" #include "arrow/flight/server.h" #include "arrow/flight/test-util.h" @@ -38,7 +39,6 @@ DEFINE_string(host, "localhost", "Server port to connect to"); DEFINE_int32(port, 31337, "Server port to connect to"); DEFINE_string(path, "", "Resource path to request"); -DEFINE_string(output, "", "Where to write requested resource"); int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing client for Flight."); @@ -49,9 +49,38 @@ int main(int argc, char** argv) { arrow::flight::FlightDescriptor descr{ arrow::flight::FlightDescriptor::PATH, "", {FLAGS_path}}; + + // 1. Put the data to the server. 
+ std::unique_ptr reader; + std::shared_ptr in_file; + std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl; + ABORT_NOT_OK(arrow::io::ReadableFile::Open(FLAGS_path, &in_file)); + + int64_t file_size = 0; + ABORT_NOT_OK(in_file->GetSize(&file_size)); + + std::shared_ptr json_buffer; + ABORT_NOT_OK(in_file->Read(file_size, &json_buffer)); + + ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(json_buffer, &reader)); + + std::unique_ptr write_stream; + ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); + + std::vector> original_chunks; + for (int i = 0; i < reader->num_record_batches(); i++) { + std::shared_ptr batch; + ABORT_NOT_OK(reader->ReadRecordBatch(i, &batch)); + original_chunks.push_back(batch); + ABORT_NOT_OK(write_stream->WriteRecordBatch(*batch)); + } + ABORT_NOT_OK(write_stream->Close()); + + // 2. Get the ticket for the data. std::unique_ptr info; ABORT_NOT_OK(client->GetFlightInfo(descr, &info)); + // 3. Download the data from the server. std::shared_ptr schema; ABORT_NOT_OK(info->GetSchema(&schema)); @@ -64,19 +93,28 @@ int main(int argc, char** argv) { std::unique_ptr stream; ABORT_NOT_OK(client->DoGet(ticket, schema, &stream)); - std::shared_ptr out_file; - ABORT_NOT_OK(arrow::io::FileOutputStream::Open(FLAGS_output, &out_file)); - std::shared_ptr writer; - ABORT_NOT_OK(arrow::ipc::RecordBatchFileWriter::Open(out_file.get(), schema, &writer)); - + std::vector> retrieved_chunks; std::shared_ptr chunk; while (true) { ABORT_NOT_OK(stream->ReadNext(&chunk)); if (chunk == nullptr) break; - ABORT_NOT_OK(writer->WriteRecordBatch(*chunk)); + retrieved_chunks.push_back(chunk); } - ABORT_NOT_OK(writer->Close()); + // 4. Validate that the data is equal. + + std::shared_ptr original_data; + std::shared_ptr retrieved_data; + + ABORT_NOT_OK( + arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data)); + ABORT_NOT_OK( + arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data)); + + if (!original_data->Equals(*retrieved_data)) { + std::cerr << "Data does not match!" 
<< std::endl; + return 1; + } return 0; } diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc index 80813e7f19a4c..c50e52713b33f 100644 --- a/cpp/src/arrow/flight/test-integration-server.cc +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -27,6 +27,7 @@ #include "arrow/io/test-common.h" #include "arrow/ipc/json.h" #include "arrow/record_batch.h" +#include "arrow/table.h" #include "arrow/flight/server.h" #include "arrow/flight/test-util.h" @@ -36,57 +37,7 @@ DEFINE_int32(port, 31337, "Server port to listen on"); namespace arrow { namespace flight { -class JsonReaderRecordBatchStream : public FlightDataStream { - public: - explicit JsonReaderRecordBatchStream( - std::unique_ptr&& reader) - : index_(0), pool_(default_memory_pool()), reader_(std::move(reader)) {} - - std::shared_ptr schema() override { return reader_->schema(); } - - Status Next(ipc::internal::IpcPayload* payload) override { - if (index_ >= reader_->num_record_batches()) { - // Signal that iteration is over - payload->metadata = nullptr; - return Status::OK(); - } - - std::shared_ptr batch; - RETURN_NOT_OK(reader_->ReadRecordBatch(index_, &batch)); - index_++; - - if (!batch) { - // Signal that iteration is over - payload->metadata = nullptr; - return Status::OK(); - } else { - return ipc::internal::GetRecordBatchPayload(*batch, pool_, payload); - } - } - - private: - int index_; - MemoryPool* pool_; - std::unique_ptr reader_; -}; - class FlightIntegrationTestServer : public FlightServerBase { - Status ReadJson(const std::string& json_path, - std::unique_ptr* out) { - std::shared_ptr in_file; - std::cout << "Opening JSON file '" << json_path << "'" << std::endl; - RETURN_NOT_OK(io::ReadableFile::Open(json_path, &in_file)); - - int64_t file_size = 0; - RETURN_NOT_OK(in_file->GetSize(&file_size)); - - std::shared_ptr json_buffer; - RETURN_NOT_OK(in_file->Read(file_size, &json_buffer)); - - RETURN_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(json_buffer, out)); - return Status::OK(); - } - Status GetFlightInfo(const FlightDescriptor& request, std::unique_ptr* info) override { if (request.type == FlightDescriptor::PATH) { @@ -94,16 +45,19 @@ class FlightIntegrationTestServer : public FlightServerBase { return Status::Invalid("Invalid path"); } - std::unique_ptr reader; - RETURN_NOT_OK(ReadJson(request.path.back(), &reader)); + auto data = uploaded_chunks.find(request.path[0]); + if (data == uploaded_chunks.end()) { + return Status::KeyError("Could not find flight."); + } + auto flight = data->second; - FlightEndpoint endpoint1({{request.path.back()}, {}}); + FlightEndpoint endpoint1({{request.path[0]}, {}}); FlightInfo::Data flight_data; - RETURN_NOT_OK(internal::SchemaToString(*reader->schema(), &flight_data.schema)); + RETURN_NOT_OK(internal::SchemaToString(*flight->schema(), &flight_data.schema)); flight_data.descriptor = request; flight_data.endpoints = {endpoint1}; - flight_data.total_records = reader->num_record_batches(); + flight_data.total_records = flight->num_rows(); flight_data.total_bytes = -1; FlightInfo value(flight_data); @@ -116,14 +70,44 @@ class FlightIntegrationTestServer : public FlightServerBase { Status DoGet(const Ticket& request, std::unique_ptr* data_stream) override { - std::unique_ptr reader; - RETURN_NOT_OK(ReadJson(request.ticket, &reader)); + auto data = uploaded_chunks.find(request.ticket); + if (data == uploaded_chunks.end()) { + return Status::KeyError("Could not find flight."); + } + auto flight = data->second; - 
*data_stream = std::unique_ptr( - new JsonReaderRecordBatchStream(std::move(reader))); + *data_stream = std::unique_ptr(new RecordBatchStream( + std::shared_ptr(new TableBatchReader(*flight)))); return Status::OK(); } + + Status DoPut(std::unique_ptr reader) override { + const FlightDescriptor& descriptor = reader->descriptor(); + + if (descriptor.type != FlightDescriptor::DescriptorType::PATH) { + return Status::Invalid("Must specify a path"); + } else if (descriptor.path.size() < 1) { + return Status::Invalid("Must specify a path"); + } + + std::string key = descriptor.path[0]; + + std::vector> retrieved_chunks; + std::shared_ptr chunk; + while (true) { + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + retrieved_chunks.push_back(chunk); + } + std::shared_ptr retrieved_data; + RETURN_NOT_OK(arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + &retrieved_data)); + uploaded_chunks[key] = retrieved_data; + return Status::OK(); + } + + std::unordered_map> uploaded_chunks; }; } // namespace flight diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 0362105bbc592..e4251bdd5d21b 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -152,12 +152,6 @@ class FlightInfo { mutable bool reconstructed_schema_; }; -// TODO(wesm): NYI -class ARROW_EXPORT FlightPutWriter { - public: - virtual ~FlightPutWriter() = default; -}; - /// \brief An iterator to FlightInfo instances returned by ListFlights class ARROW_EXPORT FlightListing { public: diff --git a/integration/integration_test.py b/integration/integration_test.py index 0bced26f15acd..cef4e5697b29e 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -1009,23 +1009,10 @@ def _compare_flight_implementations(self, producer, consumer): print('Testing file {0}'.format(json_path)) print('==========================================================') - name = os.path.splitext(os.path.basename(json_path))[0] - - file_id = guid()[:8] - with producer.flight_server(): - # Have the client request the file - consumer_file_path = os.path.join( - self.temp_dir, - file_id + '_' + name + '.consumer_requested_file') - consumer.flight_request(producer.FLIGHT_PORT, - json_path, consumer_file_path) - - # Validate the file - print('-- Validating file') - consumer.validate(json_path, consumer_file_path) - - # TODO: also have the client upload the file + # Have the client upload the file, then download and + # compare + consumer.flight_request(producer.FLIGHT_PORT, json_path) class Tester(object): @@ -1053,7 +1040,7 @@ def validate(self, json_path, arrow_path): def flight_server(self): raise NotImplementedError - def flight_request(self, port, json_path, arrow_path): + def flight_request(self, port, json_path): raise NotImplementedError @@ -1122,12 +1109,11 @@ def file_to_stream(self, file_path, stream_path): print(' '.join(cmd)) run_cmd(cmd) - def flight_request(self, port, json_path, arrow_path): + def flight_request(self, port, json_path): cmd = ['java', '-cp', self.ARROW_FLIGHT_JAR, self.ARROW_FLIGHT_CLIENT, '-port', str(port), - '-j', json_path, - '-a', arrow_path] + '-j', json_path] if self.debug: print(' '.join(cmd)) run_cmd(cmd) @@ -1230,15 +1216,14 @@ def flight_server(self): server.terminate() server.wait(5) - def flight_request(self, port, json_path, arrow_path): + def flight_request(self, port, json_path): cmd = self.FLIGHT_CLIENT_CMD + [ '-port=' + str(port), '-path=' + json_path, - '-output=' + arrow_path ] if self.debug: print(' 
'.join(cmd)) - subprocess.run(cmd) + run_cmd(cmd) class JSTester(Tester): diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java index af7445fb416fd..dd06f79479064 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java @@ -145,6 +145,7 @@ public void listActions(StreamListener listener) { listener.onNext(new ActionType("get", "pull a stream. Action must be done via standard get mechanism")); listener.onNext(new ActionType("put", "push a stream. Action must be done via standard get mechanism")); listener.onNext(new ActionType("drop", "delete a flight. Action body is a JSON encoded path.")); + listener.onCompleted(); } @Override diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java index 803a56c6c1afe..3c04a19c9de8d 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -18,7 +18,6 @@ package org.apache.arrow.flight.example.integration; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; import java.util.List; @@ -30,9 +29,11 @@ import org.apache.arrow.flight.Location; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.DictionaryProvider; -import org.apache.arrow.vector.ipc.ArrowFileWriter; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.JsonFileReader; +import org.apache.arrow.vector.util.Validator; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLineParser; import org.apache.commons.cli.DefaultParser; @@ -48,7 +49,6 @@ class IntegrationTestClient { private IntegrationTestClient() { options = new Options(); - options.addOption("a", "arrow", true, "arrow file"); options.addOption("j", "json", true, "json file"); options.addOption("host", true, "The host to connect to."); options.addOption("port", true, "The port to connect to." 
); @@ -64,7 +64,7 @@ public static void main(String[] args) { } } - static void fatalError(String message, Throwable e) { + private static void fatalError(String message, Throwable e) { System.err.println(message); System.err.println(e.getMessage()); LOGGER.error(message, e); @@ -72,37 +72,55 @@ static void fatalError(String message, Throwable e) { } private void run(String[] args) throws ParseException, IOException { - CommandLineParser parser = new DefaultParser(); - CommandLine cmd = parser.parse(options, args, false); - - String fileName = cmd.getOptionValue("arrow"); - if (fileName == null) { - throw new IllegalArgumentException("missing arrow file parameter"); - } - File arrowFile = new File(fileName); - if (arrowFile.exists()) { - throw new IllegalArgumentException("arrow file already exists: " + arrowFile.getAbsolutePath()); - } + final CommandLineParser parser = new DefaultParser(); + final CommandLine cmd = parser.parse(options, args, false); final String host = cmd.getOptionValue("host", "localhost"); final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - FlightClient client = new FlightClient(allocator, new Location(host, port)); - FlightInfo info = client.getInfo(FlightDescriptor.path(cmd.getOptionValue("json"))); + final FlightClient client = new FlightClient(allocator, new Location(host, port)); + + final String inputPath = cmd.getOptionValue("j"); + + // 1. Read data from JSON and upload to server. + FlightDescriptor descriptor = FlightDescriptor.path(inputPath); + VectorSchemaRoot jsonRoot; + try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator); + VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { + jsonRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + VectorUnloader unloader = new VectorUnloader(root); + VectorLoader jsonLoader = new VectorLoader(jsonRoot); + FlightClient.ClientStreamListener stream = client.startPut(descriptor, root); + while (reader.read(root)) { + stream.putNext(); + jsonLoader.load(unloader.getRecordBatch()); + root.clear(); + } + stream.completed(); + // Need to call this, or exceptions from the server get swallowed + stream.getResult(); + } + + // 2. Get the ticket for the data. + FlightInfo info = client.getInfo(descriptor); List endpoints = info.getEndpoints(); if (endpoints.isEmpty()) { throw new RuntimeException("No endpoints returned from Flight server."); } + // 3. Download the data from the server. 
FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket()); - try (VectorSchemaRoot root = stream.getRoot(); - FileOutputStream fileOutputStream = new FileOutputStream(arrowFile); - ArrowFileWriter arrowWriter = new ArrowFileWriter(root, new DictionaryProvider.MapDictionaryProvider(), - fileOutputStream.getChannel())) { + VectorSchemaRoot downloadedRoot; + try (VectorSchemaRoot root = stream.getRoot()) { + downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + VectorLoader loader = new VectorLoader(downloadedRoot); + VectorUnloader unloader = new VectorUnloader(root); while (stream.next()) { - arrowWriter.writeBatch(); + loader.load(unloader.getRecordBatch()); } } + + Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot); } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java index 2b78cca93aaef..eff2f5d4126cc 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java @@ -17,34 +17,11 @@ package org.apache.arrow.flight.example.integration; -import java.io.File; -import java.io.FileOutputStream; -import java.nio.charset.StandardCharsets; -import java.util.Collections; -import java.util.concurrent.Callable; - -import org.apache.arrow.flight.Action; -import org.apache.arrow.flight.ActionType; -import org.apache.arrow.flight.Criteria; -import org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.flight.FlightEndpoint; -import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightProducer; -import org.apache.arrow.flight.FlightServer; -import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.Result; -import org.apache.arrow.flight.Ticket; -import org.apache.arrow.flight.auth.ServerAuthHandler; -import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.flight.example.ExampleFlightServer; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.dictionary.DictionaryProvider; -import org.apache.arrow.vector.ipc.ArrowFileWriter; -import org.apache.arrow.vector.ipc.JsonFileReader; -import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.util.AutoCloseables; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLineParser; import org.apache.commons.cli.DefaultParser; @@ -52,6 +29,7 @@ import org.apache.commons.cli.ParseException; class IntegrationTestServer { + private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestServer.class); private final Options options; private IntegrationTestServer() { @@ -62,17 +40,25 @@ private IntegrationTestServer() { private void run(String[] args) throws Exception { CommandLineParser parser = new DefaultParser(); CommandLine cmd = parser.parse(options, args, false); + final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); - try (final IntegrationFlightProducer producer = new IntegrationFlightProducer(allocator); 
- final FlightServer server = new FlightServer(allocator, port, producer, ServerAuthHandler.NO_OP)) { - server.start(); - // Print out message for integration test script - System.out.println("Server listening on localhost:" + server.getPort()); - while (true) { - Thread.sleep(30000); + final ExampleFlightServer efs = new ExampleFlightServer(allocator, new Location("localhost", port)); + efs.start(); + // Print out message for integration test script + System.out.println("Server listening on localhost:" + port); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + System.out.println("\nExiting..."); + AutoCloseables.close(efs, allocator); + } catch (Exception e) { + e.printStackTrace(); } + })); + + while (true) { + Thread.sleep(30000); } } @@ -80,104 +66,17 @@ public static void main(String[] args) { try { new IntegrationTestServer().run(args); } catch (ParseException e) { - IntegrationTestClient.fatalError("Error parsing arguments", e); + fatalError("Error parsing arguments", e); } catch (Exception e) { - IntegrationTestClient.fatalError("Runtime error", e); + fatalError("Runtime error", e); } } - static class IntegrationFlightProducer implements FlightProducer, AutoCloseable { - private final BufferAllocator allocator; - - IntegrationFlightProducer(BufferAllocator allocator) { - this.allocator = allocator; - } - - @Override - public void close() { - allocator.close(); - } - - @Override - public void getStream(Ticket ticket, ServerStreamListener listener) { - String path = new String(ticket.getBytes(), StandardCharsets.UTF_8); - File inputFile = new File(path); - try (JsonFileReader reader = new JsonFileReader(inputFile, allocator)) { - Schema schema = reader.start(); - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - listener.start(root); - while (reader.read(root)) { - listener.putNext(); - } - listener.completed(); - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public void listFlights(Criteria criteria, StreamListener listener) { - listener.onCompleted(); - } - - @Override - public FlightInfo getFlightInfo(FlightDescriptor descriptor) { - if (descriptor.isCommand()) { - throw new UnsupportedOperationException("Commands not supported."); - } - if (descriptor.getPath().size() < 1) { - throw new IllegalArgumentException("Must provide a path."); - } - String path = descriptor.getPath().get(0); - File inputFile = new File(path); - try (JsonFileReader reader = new JsonFileReader(inputFile, allocator)) { - Schema schema = reader.start(); - return new FlightInfo(schema, descriptor, - Collections.singletonList(new FlightEndpoint(new Ticket(path.getBytes()), - new Location("localhost", 31338))), - 0, 0); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public Callable acceptPut(FlightStream flightStream) { - return () -> { - if (flightStream.getDescriptor().isCommand()) { - throw new UnsupportedOperationException("Commands not supported."); - } - if (flightStream.getDescriptor().getPath().size() < 1) { - throw new IllegalArgumentException("Must provide a path."); - } - String path = flightStream.getDescriptor().getPath().get(0); - File outputFile = new File(path); - if (!outputFile.createNewFile()) { - throw new IllegalStateException("File already exists."); - } - try (VectorSchemaRoot root = flightStream.getRoot(); - FileOutputStream fileOutputStream = new FileOutputStream(outputFile); - ArrowFileWriter writer = new ArrowFileWriter(root, new 
DictionaryProvider.MapDictionaryProvider(), - fileOutputStream.getChannel())) { - writer.start(); - while (flightStream.next()) { - writer.writeBatch(); - } - writer.end(); - } - return Flight.PutResult.getDefaultInstance(); - }; - } - - @Override - public Result doAction(Action action) { - throw new UnsupportedOperationException("No actions implemented."); - } - - @Override - public void listActions(StreamListener listener) { - listener.onCompleted(); - } + private static void fatalError(String message, Throwable e) { + System.err.println(message); + System.err.println(e.getMessage()); + LOGGER.error(message, e); + System.exit(1); } + }
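Taken together, the changes above give C++ clients a writer-based upload path mirroring DoGet. A compact usage sketch of the new client API, assuming a FlightClient::Connect(host, port, &client) factory in client.h and the DoPut(descriptor, schema, &writer) overload added by this patch; UploadBatches and the "example-dataset" path are illustrative only:

    // Sketch of driving the new upload path from application code. The batches are
    // assumed to share `schema`; error handling is collapsed into RETURN_NOT_OK.
    #include <memory>
    #include <string>
    #include <vector>

    #include "arrow/flight/client.h"
    #include "arrow/ipc/writer.h"
    #include "arrow/record_batch.h"
    #include "arrow/status.h"
    #include "arrow/type.h"

    arrow::Status UploadBatches(
        const std::string& host, int port,
        const std::shared_ptr<arrow::Schema>& schema,
        const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches) {
      std::unique_ptr<arrow::flight::FlightClient> client;
      RETURN_NOT_OK(arrow::flight::FlightClient::Connect(host, port, &client));

      // Identify the dataset by path, as the integration tests above do.
      arrow::flight::FlightDescriptor descriptor{
          arrow::flight::FlightDescriptor::PATH, "", {"example-dataset"}};

      // DoPut sends the descriptor and schema up front and hands back a
      // RecordBatchWriter; each WriteRecordBatch() becomes one FlightData message.
      std::unique_ptr<arrow::ipc::RecordBatchWriter> writer;
      RETURN_NOT_OK(client->DoPut(descriptor, schema, &writer));
      for (const auto& batch : batches) {
        RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
      }
      // Close() issues WritesDone() and Finish() on the underlying gRPC stream.
      return writer->Close();
    }

If the server implements DoPut as the integration test server above does, a subsequent GetFlightInfo and DoGet on the same path returns the uploaded batches.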