diff --git a/cpp/cmake_modules/Findc-ares.cmake b/cpp/cmake_modules/Findc-ares.cmake new file mode 100644 index 0000000000000..1366ce33fa790 --- /dev/null +++ b/cpp/cmake_modules/Findc-ares.cmake @@ -0,0 +1,108 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Tries to find c-ares headers and libraries. +# +# Usage of this module as follows: +# +# find_package(c-ares) +# +# Variables used by this module, they can change the default behaviour and need +# to be set before calling find_package: +# +# CARES_HOME - When set, this path is inspected instead of standard library +# locations as the root of the c-ares installation. +# The environment variable CARES_HOME overrides this variable. +# +# - Find CARES +# This module defines +# CARES_INCLUDE_DIR, directory containing headers +# CARES_SHARED_LIB, path to c-ares's shared library +# CARES_FOUND, whether c-ares has been found + +if( NOT "${CARES_HOME}" STREQUAL "") + file( TO_CMAKE_PATH "${CARES_HOME}" _native_path ) + list( APPEND _cares_roots ${_native_path} ) +elseif ( CARES_HOME ) + list( APPEND _cares_roots ${CARES_HOME} ) +endif() + +if (MSVC) + set(CARES_LIB_NAME cares.lib) +else () + set(CARES_LIB_NAME + ${CMAKE_SHARED_LIBRARY_PREFIX}cares${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(CARES_STATIC_LIB_NAME + ${CMAKE_STATIC_LIBRARY_PREFIX}cares${CMAKE_STATIC_LIBRARY_SUFFIX}) +endif () + +# Try the parameterized roots, if they exist +if (_cares_roots) + find_path(CARES_INCLUDE_DIR NAMES ares.h + PATHS ${_cares_roots} NO_DEFAULT_PATH + PATH_SUFFIXES "include") + find_library(CARES_SHARED_LIB + NAMES ${CARES_LIB_NAME} + PATHS ${_cares_roots} NO_DEFAULT_PATH + PATH_SUFFIXES "lib") + find_library(CARES_STATIC_LIB + NAMES ${CARES_STATIC_LIB_NAME} + PATHS ${_cares_roots} NO_DEFAULT_PATH + PATH_SUFFIXES "lib") +else () + pkg_check_modules(PKG_CARES cares) + if (PKG_CARES_FOUND) + set(CARES_INCLUDE_DIR ${PKG_CARES_INCLUDEDIR}) + find_library(CARES_SHARED_LIB + NAMES ${CARES_LIB_NAME} + PATHS ${PKG_CARES_LIBDIR} NO_DEFAULT_PATH) + else () + find_path(CARES_INCLUDE_DIR NAMES cares.h) + find_library(CARES_SHARED_LIB NAMES ${CARES_LIB_NAME}) + endif () +endif () + +if (CARES_INCLUDE_DIR AND CARES_SHARED_LIB) + set(CARES_FOUND TRUE) +else () + set(CARES_FOUND FALSE) +endif () + +if (CARES_FOUND) + if (NOT CARES_FIND_QUIETLY) + if (CARES_SHARED_LIB) + message(STATUS "Found the c-ares shared library: ${CARES_SHARED_LIB}") + endif () + endif () +else () + if (NOT CARES_FIND_QUIETLY) + set(CARES_ERR_MSG "Could not find the c-ares library. 
Looked in ") + if ( _cares_roots ) + set(CARES_ERR_MSG "${CARES_ERR_MSG} ${_cares_roots}.") + else () + set(CARES_ERR_MSG "${CARES_ERR_MSG} system search paths.") + endif () + if (CARES_FIND_REQUIRED) + message(FATAL_ERROR "${CARES_ERR_MSG}") + else (CARES_FIND_REQUIRED) + message(STATUS "${CARES_ERR_MSG}") + endif (CARES_FIND_REQUIRED) + endif () +endif () + +mark_as_advanced( + CARES_INCLUDE_DIR + CARES_LIBRARIES + CARES_SHARED_LIB + CARES_STATIC_LIB +) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 5ee0ddfd55914..9bd6cb68c70e7 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1347,12 +1347,7 @@ if (ARROW_WITH_GRPC) BUILD_BYPRODUCTS "${CARES_STATIC_LIB}") else() set(CARES_VENDORED 0) - find_package(c-ares REQUIRED - PATHS ${CARES_HOME} - NO_DEFAULT_PATH) - if(TARGET c-ares::cares) - get_property(CARES_STATIC_LIB TARGET c-ares::cares_static PROPERTY LOCATION) - endif() + find_package(c-ares REQUIRED) endif() message(STATUS "c-ares library: ${CARES_STATIC_LIB}") @@ -1406,15 +1401,15 @@ if (ARROW_WITH_GRPC) set(GRPC_CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} - -DCMAKE_PREFIX_PATH="${GRPC_PREFIX_PATH_ALT_SEP}" - "-DgRPC_CARES_PROVIDER=package" - "-DgRPC_GFLAGS_PROVIDER=package" - "-DgRPC_PROTOBUF_PROVIDER=package" - "-DgRPC_SSL_PROVIDER=package" - "-DgRPC_ZLIB_PROVIDER=package" - "-DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS}" - "-DCMAKE_C_FLAGS=${EP_C_FLAGS}" - "-DCMAKE_INSTALL_PREFIX=${GRPC_PREFIX}" + -DCMAKE_PREFIX_PATH='${GRPC_PREFIX_PATH_ALT_SEP}' + '-DgRPC_CARES_PROVIDER=package' + '-DgRPC_GFLAGS_PROVIDER=package' + '-DgRPC_PROTOBUF_PROVIDER=package' + '-DgRPC_SSL_PROVIDER=package' + '-DgRPC_ZLIB_PROVIDER=package' + '-DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS}' + '-DCMAKE_C_FLAGS=${EP_C_FLAGS}' + '-DCMAKE_INSTALL_PREFIX=${GRPC_PREFIX}' -DCMAKE_INSTALL_LIBDIR=lib -DBUILD_SHARED_LIBS=OFF) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index b8b4d8d336365..1cbef6cf81808 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -21,6 +21,7 @@ add_custom_target(arrow_flight) ARROW_INSTALL_ALL_HEADERS("arrow/flight") SET(ARROW_FLIGHT_STATIC_LINK_LIBS + protobuf_static grpc_grpcpp_static grpc_grpc_static grpc_gpr_static @@ -69,6 +70,7 @@ set(ARROW_FLIGHT_SRCS Flight.pb.cc Flight.grpc.pb.cc internal.cc + serialization-internal.cc server.cc types.cc ) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index e25c1875d669f..a58c2b5933225 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -22,12 +22,12 @@ #include #include -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/wire_format_lite.h" -#include "grpc/byte_buffer_reader.h" #include "grpcpp/grpcpp.h" +#include "arrow/ipc/dictionary.h" +#include "arrow/ipc/metadata-internal.h" #include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/type.h" @@ -36,161 +36,11 @@ #include "arrow/flight/Flight.grpc.pb.h" #include "arrow/flight/Flight.pb.h" #include "arrow/flight/internal.h" +#include "arrow/flight/serialization-internal.h" -namespace pb = arrow::flight::protocol; - -namespace arrow { -namespace flight { - -/// Internal, not user-visible type used for memory-efficient reads from gRPC -/// stream -struct FlightData { - /// Used only for puts, may be null - std::unique_ptr descriptor; - - /// Non-length-prefixed 
Message header as described in format/Message.fbs - std::shared_ptr metadata; - - /// Message body - std::shared_ptr body; -}; - -} // namespace flight -} // namespace arrow - -namespace grpc { - -// Customizations to gRPC for more efficient deserialization of FlightData - -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::CodedInputStream; - -using arrow::flight::FlightData; - -bool ReadBytesZeroCopy(const std::shared_ptr& source_data, - CodedInputStream* input, std::shared_ptr* out) { - uint32_t length; - if (!input->ReadVarint32(&length)) { - return false; - } - *out = arrow::SliceBuffer(source_data, input->CurrentPosition(), - static_cast(length)); - return input->Skip(static_cast(length)); -} - -// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow -// consumers with zero-copy -class GrpcBuffer : public arrow::MutableBuffer { - public: - GrpcBuffer(grpc_slice slice, bool incref) - : MutableBuffer(GRPC_SLICE_START_PTR(slice), - static_cast(GRPC_SLICE_LENGTH(slice))), - slice_(incref ? grpc_slice_ref(slice) : slice) {} - - ~GrpcBuffer() override { - // Decref slice - grpc_slice_unref(slice_); - } - - static arrow::Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr* out) { - // These types are guaranteed by static assertions in gRPC to have the same - // in-memory representation - - auto buffer = *reinterpret_cast(cpp_buf); - - // This part below is based on the Flatbuffers gRPC SerializationTraits in - // flatbuffers/grpc.h - - // Check if this is a single uncompressed slice. - if ((buffer->type == GRPC_BB_RAW) && - (buffer->data.raw.compression == GRPC_COMPRESS_NONE) && - (buffer->data.raw.slice_buffer.count == 1)) { - // If it is, then we can reference the `grpc_slice` directly. - grpc_slice slice = buffer->data.raw.slice_buffer.slices[0]; - - // Increment reference count so this memory remains valid - *out = std::make_shared(slice, true); - } else { - // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read - // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives - // us back a new slice with the refcount already incremented. 
- grpc_byte_buffer_reader reader; - if (!grpc_byte_buffer_reader_init(&reader, buffer)) { - return arrow::Status::IOError("Internal gRPC error reading from ByteBuffer"); - } - grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); - grpc_byte_buffer_reader_destroy(&reader); - - // Steal the slice reference - *out = std::make_shared(slice, false); - } - - return arrow::Status::OK(); - } - - private: - grpc_slice slice_; -}; - -// Read internal::FlightData from grpc::ByteBuffer containing FlightData -// protobuf without copying -template <> -class SerializationTraits { - public: - static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { - return Status(StatusCode::UNIMPLEMENTED, - "internal::FlightData serialization not implemented"); - } - - static Status Deserialize(ByteBuffer* buffer, FlightData* out) { - if (!buffer) { - return Status(StatusCode::INTERNAL, "No payload"); - } +using arrow::ipc::internal::IpcPayload; - std::shared_ptr wrapped_buffer; - GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer)); - - auto buffer_length = static_cast(wrapped_buffer->size()); - CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); - - // TODO(wesm): The 2-parameter version of this function is deprecated - pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */); - - // This is the bytes remaining when using CodedInputStream like this - while (pb_stream.BytesUntilTotalBytesLimit()) { - const uint32_t tag = pb_stream.ReadTag(); - const int field_number = WireFormatLite::GetTagFieldNumber(tag); - switch (field_number) { - case pb::FlightData::kFlightDescriptorFieldNumber: { - pb::FlightDescriptor pb_descriptor; - if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { - return Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor"); - } - } break; - case pb::FlightData::kDataHeaderFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData metadata"); - } - } break; - case pb::FlightData::kDataBodyFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData body"); - } - } break; - default: - DCHECK(false) << "cannot happen"; - } - } - buffer->Clear(); - - // TODO(wesm): Where and when should we verify that the FlightData is not - // malformed or missing components? - - return Status::OK; - } -}; - -} // namespace grpc +namespace pb = arrow::flight::protocol; namespace arrow { namespace flight { @@ -225,9 +75,12 @@ class FlightStreamReader : public RecordBatchReader { } // For customizing read path for better memory/serialization efficiency + // XXX this cast is undefined behavior auto custom_reader = reinterpret_cast*>(stream_.get()); - if (custom_reader->Read(&data)) { + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (custom_reader->grpc::ClientReader::Read(&data)) { std::unique_ptr message; // Validate IPC message @@ -259,6 +112,82 @@ class FlightStreamReader : public RecordBatchReader { std::unique_ptr> stream_; }; +class FlightClient; + +/// \brief A RecordBatchWriter implementation that writes to a Flight +/// DoPut stream. 
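// Illustrative client-side sketch of the DoPut path added below (names such
// as `client`, `descriptor`, `schema` and `batch` are assumed to be in scope;
// this snippet is an editorial aside, not part of the patch):
//
//   std::unique_ptr<FlightPutWriter> writer;
//   RETURN_NOT_OK(client->DoPut(descriptor, schema, &writer));
//   RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
//   RETURN_NOT_OK(writer->Close());
//
// Each WriteRecordBatch call converts the batch into an IpcPayload and hands
// it to the gRPC stream through the IpcPayload SerializationTraits, avoiding
// an extra copy into the FlightData protobuf.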
+class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter { + public: + explicit FlightPutWriterImpl(std::unique_ptr rpc, + const FlightDescriptor& descriptor, + const std::shared_ptr& schema, + MemoryPool* pool = default_memory_pool()) + : rpc_(std::move(rpc)), descriptor_(descriptor), schema_(schema), pool_(pool) {} + + Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { + IpcPayload payload; + RETURN_NOT_OK(ipc::internal::GetRecordBatchPayload(batch, pool_, &payload)); + // XXX this cast is undefined behavior + auto custom_writer = reinterpret_cast*>(writer_.get()); + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (!custom_writer->grpc::ClientWriter::Write(payload, + grpc::WriteOptions())) { + std::stringstream ss; + ss << "Could not write record batch to stream: " + << rpc_->context.debug_error_string(); + return Status::IOError(ss.str()); + } + return Status::OK(); + } + + Status Close() override { + bool finished_writes = writer_->WritesDone(); + RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish())); + if (!finished_writes) { + return Status::UnknownError( + "Could not finish writing record batches before closing"); + } + return Status::OK(); + } + + void set_memory_pool(MemoryPool* pool) override { pool_ = pool; } + + private: + /// \brief Set the gRPC writer backing this Flight stream. + /// \param [in] writer the gRPC writer + void set_stream(std::unique_ptr> writer) { + writer_ = std::move(writer); + } + + // TODO: there isn't a way to access this as a user. + protocol::PutResult response; + std::unique_ptr rpc_; + FlightDescriptor descriptor_; + std::shared_ptr schema_; + std::unique_ptr> writer_; + MemoryPool* pool_; + + // We need to reference some fields + friend class FlightClient; +}; + +FlightPutWriter::~FlightPutWriter() {} + +FlightPutWriter::FlightPutWriter(std::unique_ptr impl) { + impl_ = std::move(impl); +} + +Status FlightPutWriter::WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) { + return impl_->WriteRecordBatch(batch, allow_64bit); +} + +Status FlightPutWriter::Close() { return impl_->Close(); } + +void FlightPutWriter::set_memory_pool(MemoryPool* pool) { + return impl_->set_memory_pool(pool); +} + class FlightClient::FlightClientImpl { public: Status Connect(const std::string& host, int port) { @@ -364,8 +293,38 @@ class FlightClient::FlightClientImpl { return Status::OK(); } - Status DoPut(std::unique_ptr* stream) { - return Status::NotImplemented("DoPut"); + Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, + std::unique_ptr* stream) { + std::unique_ptr rpc(new ClientRpc); + std::unique_ptr out( + new FlightPutWriter::FlightPutWriterImpl(std::move(rpc), descriptor, schema)); + std::unique_ptr> write_stream( + stub_->DoPut(&out->rpc_->context, &out->response)); + + // First write the descriptor and schema to the stream. 
+ pb::FlightData descriptor_message; + RETURN_NOT_OK( + internal::ToProto(descriptor, descriptor_message.mutable_flight_descriptor())); + + std::shared_ptr header_buf; + RETURN_NOT_OK(Buffer::FromString("", &header_buf)); + ipc::DictionaryMemo dictionary_memo; + RETURN_NOT_OK(ipc::SerializeSchema(*schema, out->pool_, &header_buf)); + RETURN_NOT_OK( + ipc::internal::WriteSchemaMessage(*schema, &dictionary_memo, &header_buf)); + descriptor_message.set_data_header(header_buf->ToString()); + + if (!write_stream->Write(descriptor_message, grpc::WriteOptions())) { + std::stringstream ss; + ss << "Could not write descriptor and schema to stream: " + << rpc->context.debug_error_string(); + return Status::IOError(ss.str()); + } + + out->set_stream(std::move(write_stream)); + *stream = + std::unique_ptr(new FlightPutWriter(std::move(out))); + return Status::OK(); } private: @@ -410,9 +369,10 @@ Status FlightClient::DoGet(const Ticket& ticket, const std::shared_ptr& return impl_->DoGet(ticket, schema, stream); } -Status FlightClient::DoPut(const Schema& schema, - std::unique_ptr* stream) { - return Status::NotImplemented("DoPut"); +Status FlightClient::DoPut(const FlightDescriptor& descriptor, + const std::shared_ptr& schema, + std::unique_ptr* stream) { + return impl_->DoPut(descriptor, schema, stream); } } // namespace flight diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 53bb1755b2995..ef960417b024a 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -24,6 +24,7 @@ #include #include +#include "arrow/ipc/writer.h" #include "arrow/status.h" #include "arrow/util/visibility.h" @@ -37,6 +38,8 @@ class Schema; namespace flight { +class FlightPutWriter; + /// \brief Client class for Arrow Flight RPC services (gRPC-based). /// API experimental for now class ARROW_EXPORT FlightClient { @@ -86,7 +89,7 @@ class ARROW_EXPORT FlightClient { /// \brief Given a flight ticket and schema, request to be sent the /// stream. Returns record batch stream reader - /// \param[in] ticket + /// \param[in] ticket The flight ticket to use /// \param[in] schema the schema of the stream data as computed by /// GetFlightInfo /// \param[out] stream the returned RecordBatchReader @@ -94,12 +97,15 @@ class ARROW_EXPORT FlightClient { Status DoGet(const Ticket& ticket, const std::shared_ptr& schema, std::unique_ptr* stream); - /// \brief Initiate DoPut RPC, returns FlightPutWriter interface to - /// write. Not yet implemented - /// \param[in] schema the schema of the stream data - /// \param[out] stream the created stream to write record batches to + /// \brief Upload data to a Flight described by the given + /// descriptor. The caller must call Close() on the returned stream + /// once they are done writing. 
+ /// \param[in] descriptor the descriptor of the stream + /// \param[in] schema the schema for the data to upload + /// \param[out] stream a writer to write record batches to /// \return Status - Status DoPut(const Schema& schema, std::unique_ptr* stream); + Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, + std::unique_ptr* stream); private: FlightClient(); @@ -107,5 +113,22 @@ class ARROW_EXPORT FlightClient { std::unique_ptr impl_; }; +/// \brief An interface to upload record batches to a Flight server +class ARROW_EXPORT FlightPutWriter : public ipc::RecordBatchWriter { + public: + ~FlightPutWriter(); + + Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override; + Status Close() override; + void set_memory_pool(MemoryPool* pool) override; + + private: + class FlightPutWriterImpl; + explicit FlightPutWriter(std::unique_ptr impl); + std::unique_ptr impl_; + + friend class FlightClient; +}; + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc index b4c6b2addcc11..a614450e8d0a0 100644 --- a/cpp/src/arrow/flight/internal.cc +++ b/cpp/src/arrow/flight/internal.cc @@ -131,6 +131,21 @@ void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket) { pb_ticket->set_ticket(ticket.ticket); } +// FlightData + +Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, + std::unique_ptr* message) { + RETURN_NOT_OK(internal::FromProto(pb_data.flight_descriptor(), descriptor)); + const std::string& header = pb_data.data_header(); + const std::string& body = pb_data.data_body(); + std::shared_ptr header_buf = Buffer::Wrap(header.data(), header.size()); + std::shared_ptr body_buf = Buffer::Wrap(body.data(), body.size()); + if (header_buf == nullptr || body_buf == nullptr) { + return Status::UnknownError("Could not create buffers from protobuf"); + } + return ipc::Message::Open(header_buf, body_buf, message); +} + // FlightEndpoint Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint) { @@ -156,7 +171,7 @@ Status FromProto(const pb::FlightDescriptor& pb_descriptor, FlightDescriptor* descriptor) { if (pb_descriptor.type() == pb::FlightDescriptor::PATH) { descriptor->type = FlightDescriptor::PATH; - descriptor->path.resize(pb_descriptor.path_size()); + descriptor->path.reserve(pb_descriptor.path_size()); for (int i = 0; i < pb_descriptor.path_size(); ++i) { descriptor->path.emplace_back(pb_descriptor.path(i)); } diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h index bae1eedfa9c66..7f9bda138cbb1 100644 --- a/cpp/src/arrow/flight/internal.h +++ b/cpp/src/arrow/flight/internal.h @@ -57,6 +57,8 @@ Status FromProto(const pb::Result& pb_result, Result* result); Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria); Status FromProto(const pb::Location& pb_location, Location* location); Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket); +Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, + std::unique_ptr* message); Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr); Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint); Status FromProto(const pb::FlightGetInfo& pb_info, FlightInfo::Data* info); diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc new file mode 100644 index 0000000000000..194a7b5bc0c30 --- /dev/null +++ 
b/cpp/src/arrow/flight/serialization-internal.cc @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/serialization-internal.h" + +namespace arrow { +namespace flight { +namespace internal { + +bool ReadBytesZeroCopy(const std::shared_ptr& source_data, + CodedInputStream* input, std::shared_ptr* out) { + uint32_t length; + if (!input->ReadVarint32(&length)) { + return false; + } + *out = arrow::SliceBuffer(source_data, input->CurrentPosition(), + static_cast(length)); + return input->Skip(static_cast(length)); +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h new file mode 100644 index 0000000000000..d4254d606d40f --- /dev/null +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -0,0 +1,345 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// (De)serialization utilities that hook into gRPC, efficiently +// handling Arrow-encoded data in a gRPC call. 
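//
// Roughly: a FlightData protobuf on the wire consists of three
// length-delimited fields (flight_descriptor, data_header, data_body). The
// SerializationTraits<FlightData>::Deserialize specialization below walks the
// raw gRPC slice with a CodedInputStream and exposes data_header and
// data_body as SliceBuffer views over that slice instead of copying them.
// Simplified sketch of the zero-copy read (see ReadBytesZeroCopy):
//
//   uint32_t length;
//   if (!input->ReadVarint32(&length)) return false;
//   *out = arrow::SliceBuffer(source_data, input->CurrentPosition(),
//                             static_cast<int64_t>(length));
//   return input->Skip(static_cast<int>(length));
//
// The SerializationTraits<IpcPayload>::Serialize specialization goes the
// other way: it hand-writes the protobuf tags and varint lengths into a
// pre-sized grpc::Slice, then appends the IPC metadata and the 8-byte padded
// body buffers directly.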
+ +#pragma once + +#include +#include + +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/wire_format_lite.h" +#include "grpc/byte_buffer_reader.h" +#include "grpcpp/grpcpp.h" + +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/status.h" +#include "arrow/util/logging.h" + +#include "arrow/flight/Flight.grpc.pb.h" +#include "arrow/flight/Flight.pb.h" +#include "arrow/flight/internal.h" +#include "arrow/flight/types.h" + +namespace pb = arrow::flight::protocol; + +using arrow::ipc::internal::IpcPayload; + +constexpr int64_t kInt32Max = std::numeric_limits::max(); + +namespace arrow { +namespace flight { + +/// Internal, not user-visible type used for memory-efficient reads from gRPC +/// stream +struct FlightData { + /// Used only for puts, may be null + std::unique_ptr descriptor; + + /// Non-length-prefixed Message header as described in format/Message.fbs + std::shared_ptr metadata; + + /// Message body + std::shared_ptr body; +}; + +namespace internal { + +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; + +// More efficient writing of FlightData to gRPC output buffer +// Implementation of ZeroCopyOutputStream that writes to a fixed-size buffer +class FixedSizeProtoWriter : public ::google::protobuf::io::ZeroCopyOutputStream { + public: + explicit FixedSizeProtoWriter(grpc_slice slice) + : slice_(slice), + bytes_written_(0), + total_size_(static_cast(GRPC_SLICE_LENGTH(slice))) {} + + bool Next(void** data, int* size) override { + // Consume the whole slice + *data = GRPC_SLICE_START_PTR(slice_) + bytes_written_; + *size = total_size_ - bytes_written_; + bytes_written_ = total_size_; + return true; + } + + void BackUp(int count) override { bytes_written_ -= count; } + + int64_t ByteCount() const override { return bytes_written_; } + + private: + grpc_slice slice_; + int bytes_written_; + int total_size_; +}; + +bool ReadBytesZeroCopy(const std::shared_ptr& source_data, + CodedInputStream* input, std::shared_ptr* out); + +// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow +// consumers with zero-copy +class GrpcBuffer : public arrow::MutableBuffer { + public: + GrpcBuffer(grpc_slice slice, bool incref) + : MutableBuffer(GRPC_SLICE_START_PTR(slice), + static_cast(GRPC_SLICE_LENGTH(slice))), + slice_(incref ? grpc_slice_ref(slice) : slice) {} + + ~GrpcBuffer() override { + // Decref slice + grpc_slice_unref(slice_); + } + + static arrow::Status Wrap(grpc::ByteBuffer* cpp_buf, + std::shared_ptr* out) { + // These types are guaranteed by static assertions in gRPC to have the same + // in-memory representation + + auto buffer = *reinterpret_cast(cpp_buf); + + // This part below is based on the Flatbuffers gRPC SerializationTraits in + // flatbuffers/grpc.h + + // Check if this is a single uncompressed slice. + if ((buffer->type == GRPC_BB_RAW) && + (buffer->data.raw.compression == GRPC_COMPRESS_NONE) && + (buffer->data.raw.slice_buffer.count == 1)) { + // If it is, then we can reference the `grpc_slice` directly. + grpc_slice slice = buffer->data.raw.slice_buffer.slices[0]; + + // Increment reference count so this memory remains valid + *out = std::make_shared(slice, true); + } else { + // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read + // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives + // us back a new slice with the refcount already incremented. 
+ grpc_byte_buffer_reader reader; + if (!grpc_byte_buffer_reader_init(&reader, buffer)) { + return arrow::Status::IOError("Internal gRPC error reading from ByteBuffer"); + } + grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); + grpc_byte_buffer_reader_destroy(&reader); + + // Steal the slice reference + *out = std::make_shared(slice, false); + } + + return arrow::Status::OK(); + } + + private: + grpc_slice slice_; +}; + +} // namespace internal + +} // namespace flight +} // namespace arrow + +namespace grpc { + +using arrow::flight::FlightData; +using arrow::flight::internal::FixedSizeProtoWriter; +using arrow::flight::internal::GrpcBuffer; +using arrow::flight::internal::ReadBytesZeroCopy; + +using google::protobuf::internal::WireFormatLite; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; + +// Helper to log status code, as gRPC doesn't expose why +// (de)serialization fails +inline Status FailSerialization(Status status) { + if (!status.ok()) { + ARROW_LOG(WARNING) << "Error deserializing Flight message: " + << status.error_message(); + } + return status; +} + +inline arrow::Status FailSerialization(arrow::Status status) { + if (!status.ok()) { + ARROW_LOG(WARNING) << "Error deserializing Flight message: " << status.ToString(); + } + return status; +} + +// Read internal::FlightData from grpc::ByteBuffer containing FlightData +// protobuf without copying +template <> +class SerializationTraits { + public: + static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { + return FailSerialization(Status( + StatusCode::UNIMPLEMENTED, "internal::FlightData serialization not implemented")); + } + + static Status Deserialize(ByteBuffer* buffer, FlightData* out) { + if (!buffer) { + return FailSerialization(Status(StatusCode::INTERNAL, "No payload")); + } + + std::shared_ptr wrapped_buffer; + GRPC_RETURN_NOT_OK(FailSerialization(GrpcBuffer::Wrap(buffer, &wrapped_buffer))); + + auto buffer_length = static_cast(wrapped_buffer->size()); + CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); + + // TODO(wesm): The 2-parameter version of this function is deprecated + pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */); + + // This is the bytes remaining when using CodedInputStream like this + while (pb_stream.BytesUntilTotalBytesLimit()) { + const uint32_t tag = pb_stream.ReadTag(); + const int field_number = WireFormatLite::GetTagFieldNumber(tag); + switch (field_number) { + case pb::FlightData::kFlightDescriptorFieldNumber: { + pb::FlightDescriptor pb_descriptor; + if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor")); + } + } break; + case pb::FlightData::kDataHeaderFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to read FlightData metadata")); + } + } break; + case pb::FlightData::kDataBodyFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to read FlightData body")); + } + } break; + default: + DCHECK(false) << "cannot happen"; + } + } + buffer->Clear(); + + // TODO(wesm): Where and when should we verify that the FlightData is not + // malformed or missing components? 
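// Note: clearing `buffer` above is safe because GrpcBuffer keeps its own
// reference on the underlying grpc_slice (or steals the one returned by
// grpc_byte_buffer_reader_readall), and the SliceBuffer views stored in
// out->metadata and out->body hold the GrpcBuffer alive in turn.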
+ + return Status::OK; + } +}; + +// Write FlightData to a grpc::ByteBuffer without extra copying +template <> +class SerializationTraits { + public: + static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) { + return FailSerialization(grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "IpcPayload deserialization not implemented")); + } + + static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out, + bool* own_buffer) { + size_t total_size = 0; + + DCHECK_LT(msg.metadata->size(), kInt32Max); + const int32_t metadata_size = static_cast(msg.metadata->size()); + + // 1 byte for metadata tag + total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); + + int64_t body_size = 0; + for (const auto& buffer : msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. + if (!buffer) continue; + + body_size += buffer->size(); + + const int64_t remainder = buffer->size() % 8; + if (remainder) { + body_size += 8 - remainder; + } + } + + // 2 bytes for body tag + // Only written when there are body buffers + if (msg.body_length > 0) { + total_size += + 2 + WireFormatLite::LengthDelimitedSize(static_cast(body_size)); + } + + // TODO(wesm): messages over 2GB unlikely to be yet supported + if (total_size > kInt32Max) { + return FailSerialization( + grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Cannot send record batches exceeding 2GB yet")); + } + + // Allocate slice, assign to output buffer + grpc::Slice slice(total_size); + + // XXX(wesm): for debugging + // std::cout << "Writing record batch with total size " << total_size << std::endl; + + FixedSizeProtoWriter writer(*reinterpret_cast(&slice)); + CodedOutputStream pb_stream(&writer); + + // Write header + WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); + pb_stream.WriteVarint32(metadata_size); + pb_stream.WriteRawMaybeAliased(msg.metadata->data(), + static_cast(msg.metadata->size())); + + // Don't write tag if there are no body buffers + if (msg.body_length > 0) { + // Write body + WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); + pb_stream.WriteVarint32(static_cast(body_size)); + + constexpr uint8_t kPaddingBytes[8] = {0}; + + for (const auto& buffer : msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. 
+ if (!buffer) continue; + + pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast(buffer->size())); + + // Write padding if not multiple of 8 + const int remainder = static_cast(buffer->size() % 8); + if (remainder) { + pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); + } + } + } + + DCHECK_EQ(static_cast(total_size), pb_stream.ByteCount()); + + // Hand off the slice to the returned ByteBuffer + grpc::ByteBuffer tmp(&slice, 1); + out->Swap(&tmp); + *own_buffer = true; + return grpc::Status::OK; + } +}; + +} // namespace grpc diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 018c079501f2f..ac5b53532866f 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -18,15 +18,12 @@ #include "arrow/flight/server.h" #include -#include #include #include -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream.h" -#include "google/protobuf/wire_format_lite.h" #include "grpcpp/grpcpp.h" +#include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -35,6 +32,7 @@ #include "arrow/flight/Flight.grpc.pb.h" #include "arrow/flight/Flight.pb.h" #include "arrow/flight/internal.h" +#include "arrow/flight/serialization-internal.h" #include "arrow/flight/types.h" using FlightService = arrow::flight::protocol::FlightService; @@ -47,145 +45,64 @@ using ServerWriter = grpc::ServerWriter; namespace pb = arrow::flight::protocol; -constexpr int64_t kInt32Max = std::numeric_limits::max(); - -namespace grpc { - -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::CodedOutputStream; +namespace arrow { +namespace flight { -// More efficient writing of FlightData to gRPC output buffer -// Implementation of ZeroCopyOutputStream that writes to a fixed-size buffer -class FixedSizeProtoWriter : public ::google::protobuf::io::ZeroCopyOutputStream { - public: - explicit FixedSizeProtoWriter(grpc_slice slice) - : slice_(slice), - bytes_written_(0), - total_size_(static_cast(GRPC_SLICE_LENGTH(slice))) {} - - bool Next(void** data, int* size) override { - // Consume the whole slice - *data = GRPC_SLICE_START_PTR(slice_) + bytes_written_; - *size = total_size_ - bytes_written_; - bytes_written_ = total_size_; - return true; +#define CHECK_ARG_NOT_NULL(VAL, MESSAGE) \ + if (VAL == nullptr) { \ + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \ } - void BackUp(int count) override { bytes_written_ -= count; } - - int64_t ByteCount() const override { return bytes_written_; } - - private: - grpc_slice slice_; - int bytes_written_; - int total_size_; -}; - -// Write FlightData to a grpc::ByteBuffer without extra copying -template <> -class SerializationTraits { +class FlightMessageReaderImpl : public FlightMessageReader { public: - static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) { - return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, - "IpcPayload deserialization not implemented"); - } - - static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out, - bool* own_buffer) { - size_t total_size = 0; - - DCHECK_LT(msg.metadata->size(), kInt32Max); - const int32_t metadata_size = static_cast(msg.metadata->size()); - - // 1 byte for metadata tag - total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); - - int64_t body_size = 0; - for (const auto& buffer : msg.body_buffers) { - // Buffer may be null when the row length is zero, or when all - // entries are invalid. 
- if (!buffer) continue; - - body_size += buffer->size(); - - const int64_t remainder = buffer->size() % 8; - if (remainder) { - body_size += 8 - remainder; - } - } - - // 2 bytes for body tag - // Only written when there are body buffers - if (msg.body_length > 0) { - total_size += - 2 + WireFormatLite::LengthDelimitedSize(static_cast(body_size)); + FlightMessageReaderImpl(const FlightDescriptor& descriptor, + std::shared_ptr schema, + grpc::ServerReader* reader) + : descriptor_(descriptor), + schema_(schema), + reader_(reader), + stream_finished_(false) {} + + const FlightDescriptor& descriptor() const override { return descriptor_; } + + std::shared_ptr schema() const override { return schema_; } + + Status ReadNext(std::shared_ptr* out) override { + if (stream_finished_) { + *out = nullptr; + return Status::OK(); } - // TODO(wesm): messages over 2GB unlikely to be yet supported - if (total_size > kInt32Max) { - return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, - "Cannot send record batches exceeding 2GB yet"); - } - - // Allocate slice, assign to output buffer - grpc::Slice slice(total_size); - - // XXX(wesm): for debugging - // std::cout << "Writing record batch with total size " << total_size << std::endl; - - FixedSizeProtoWriter writer(*reinterpret_cast(&slice)); - CodedOutputStream pb_stream(&writer); - - // Write header - WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(metadata_size); - pb_stream.WriteRawMaybeAliased(msg.metadata->data(), - static_cast(msg.metadata->size())); - - // Don't write tag if there are no body buffers - if (msg.body_length > 0) { - // Write body - WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(static_cast(body_size)); - - constexpr uint8_t kPaddingBytes[8] = {0}; - - for (const auto& buffer : msg.body_buffers) { - // Buffer may be null when the row length is zero, or when all - // entries are invalid. 
- if (!buffer) continue; - - pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast(buffer->size())); - - // Write padding if not multiple of 8 - const int remainder = static_cast(buffer->size() % 8); - if (remainder) { - pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); - } + // XXX this cast is undefined behavior + auto custom_reader = reinterpret_cast*>(reader_); + + FlightData data; + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (custom_reader->grpc::ServerReader::Read(&data)) { + std::unique_ptr message; + + // Validate IPC message + RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message)); + if (message->type() == ipc::Message::Type::RECORD_BATCH) { + return ipc::ReadRecordBatch(*message, schema_, out); + } else { + return Status(StatusCode::Invalid, "Unrecognized message in Flight stream"); } + } else { + // Stream is completed + stream_finished_ = true; + *out = nullptr; + return Status::OK(); } - - DCHECK_EQ(static_cast(total_size), pb_stream.ByteCount()); - - // Hand off the slice to the returned ByteBuffer - grpc::ByteBuffer tmp(&slice, 1); - out->Swap(&tmp); - *own_buffer = true; - return grpc::Status::OK; } -}; - -} // namespace grpc - -namespace arrow { -namespace flight { -#define CHECK_ARG_NOT_NULL(VAL, MESSAGE) \ - if (VAL == nullptr) { \ - return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \ - } + private: + FlightDescriptor descriptor_; + std::shared_ptr schema_; + grpc::ServerReader* reader_; + bool stream_finished_; +}; // This class glues an implementation of FlightServerBase together with the // gRPC service definition, so the latter is not exposed in the public API @@ -268,6 +185,7 @@ class FlightServiceImpl : public FlightService::Service { GRPC_RETURN_NOT_OK(server_->DoGet(ticket, &data_stream)); // Requires ServerWriter customization in grpc_customizations.h + // XXX this cast is undefined behavior auto custom_writer = reinterpret_cast*>(writer); // Write the schema as the first message in the stream @@ -276,7 +194,10 @@ class FlightServiceImpl : public FlightService::Service { ipc::DictionaryMemo dictionary_memo; GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload( *data_stream->schema(), pool, &dictionary_memo, &schema_payload)); - custom_writer->Write(schema_payload, grpc::WriteOptions()); + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + custom_writer->grpc::ServerWriter::Write(schema_payload, + grpc::WriteOptions()); while (true) { IpcPayload payload; @@ -293,7 +214,30 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status DoPut(ServerContext* context, grpc::ServerReader* reader, pb::PutResult* response) { - return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, ""); + // Get metadata + pb::FlightData data; + if (reader->Read(&data)) { + FlightDescriptor descriptor; + // Message only lives as long as data + std::unique_ptr message; + GRPC_RETURN_NOT_OK(internal::FromProto(data, &descriptor, &message)); + + if (!message || message->type() != ipc::Message::Type::SCHEMA) { + return internal::ToGrpcStatus( + Status(StatusCode::Invalid, "DoPut must start with schema/descriptor")); + } else { + std::shared_ptr schema; + GRPC_RETURN_NOT_OK(ipc::ReadSchema(*message, &schema)); + + auto message_reader = std::unique_ptr( + new FlightMessageReaderImpl(descriptor, schema, reader)); + return 
internal::ToGrpcStatus(server_->DoPut(std::move(message_reader))); + } + } else { + return internal::ToGrpcStatus( + Status(StatusCode::Invalid, + "Client provided malformed message or did not provide message")); + } } grpc::Status ListActions(ServerContext* context, const pb::Empty* request, @@ -376,6 +320,10 @@ Status FlightServerBase::DoGet(const Ticket& request, return Status::NotImplemented("NYI"); } +Status FlightServerBase::DoPut(std::unique_ptr reader) { + return Status::NotImplemented("NYI"); +} + Status FlightServerBase::DoAction(const Action& action, std::unique_ptr* result) { return Status::NotImplemented("NYI"); diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index b3b8239132b7a..b2e8b02be8e7d 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -29,6 +29,7 @@ #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" +#include "arrow/record_batch.h" namespace arrow { @@ -68,9 +69,9 @@ class ARROW_EXPORT FlightDataStream { /// \brief A basic implementation of FlightDataStream that will provide /// a sequence of FlightData messages to be written to a gRPC stream -/// \param[in] reader produces a sequence of record batches class ARROW_EXPORT RecordBatchStream : public FlightDataStream { public: + /// \param[in] reader produces a sequence of record batches explicit RecordBatchStream(const std::shared_ptr& reader); std::shared_ptr schema() override; @@ -81,6 +82,13 @@ class ARROW_EXPORT RecordBatchStream : public FlightDataStream { std::shared_ptr reader_; }; +/// \brief A reader for IPC payloads uploaded by a client +class ARROW_EXPORT FlightMessageReader : public RecordBatchReader { + public: + /// \brief Get the descriptor for this upload. + virtual const FlightDescriptor& descriptor() const = 0; +}; + /// \brief Skeleton RPC server implementation which can be used to create /// custom servers by implementing its abstract methods class ARROW_EXPORT FlightServerBase { @@ -90,8 +98,7 @@ class ARROW_EXPORT FlightServerBase { /// \brief Run an insecure server on localhost at the indicated port. Block /// until server is shut down or otherwise terminates - /// \param[in] port - /// \return Status + /// \param[in] port the port to bind to void Run(int port); /// \brief Shut down the server. Can be called from signal handler or another @@ -125,7 +132,10 @@ class ARROW_EXPORT FlightServerBase { /// \return Status virtual Status DoGet(const Ticket& request, std::unique_ptr* stream); - // virtual Status DoPut(std::unique_ptr* reader) = 0; + /// \brief Process a stream of IPC payloads sent from a client + /// \param[in] reader a sequence of uploaded record batches + /// \return Status + virtual Status DoPut(std::unique_ptr reader); /// \brief Execute an action, return stream of zero or more results /// \param[in] action the action to execute, with type and body diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 267025a451cc7..62522833f4ba3 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -// Client implementation for Flight integration testing. Requests the given -// path from the Flight server, which reads that file and sends it as a stream -// to the client. The client writes the server stream to the IPC file format at -// the given output file path. 
The integration test script then uses the -// existing integration test tools to compare the output binary with the -// original JSON +// Client implementation for Flight integration testing. Loads +// RecordBatches from the given JSON file and uploads them to the +// Flight server, which stores the data and schema in memory. The +// client then requests the data from the server and compares it to +// the data originally uploaded. #include #include @@ -31,6 +30,7 @@ #include "arrow/io/test-common.h" #include "arrow/ipc/json.h" #include "arrow/record_batch.h" +#include "arrow/table.h" #include "arrow/flight/server.h" #include "arrow/flight/test-util.h" @@ -38,7 +38,60 @@ DEFINE_string(host, "localhost", "Server port to connect to"); DEFINE_int32(port, 31337, "Server port to connect to"); DEFINE_string(path, "", "Resource path to request"); -DEFINE_string(output, "", "Where to write requested resource"); + +/// \brief Helper to read a RecordBatchReader into a Table. +arrow::Status ReadToTable(std::unique_ptr& reader, + std::shared_ptr* retrieved_data) { + std::vector> retrieved_chunks; + std::shared_ptr chunk; + while (true) { + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + retrieved_chunks.push_back(chunk); + } + return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + retrieved_data); +} + +/// \brief Helper to read a JsonReader into a Table. +arrow::Status ReadToTable(std::unique_ptr& reader, + std::shared_ptr* retrieved_data) { + std::vector> retrieved_chunks; + std::shared_ptr chunk; + for (int i = 0; i < reader->num_record_batches(); i++) { + RETURN_NOT_OK(reader->ReadRecordBatch(i, &chunk)); + retrieved_chunks.push_back(chunk); + } + return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + retrieved_data); +} + +/// \brief Helper to copy a RecordBatchReader to a RecordBatchWriter. +arrow::Status CopyReaderToWriter(std::unique_ptr& reader, + arrow::ipc::RecordBatchWriter& writer) { + while (true) { + std::shared_ptr chunk; + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + RETURN_NOT_OK(writer.WriteRecordBatch(*chunk)); + } + return writer.Close(); +} + +/// \brief Helper to read a flight into a Table. +arrow::Status ConsumeFlightLocation(const arrow::flight::Location& location, + const arrow::flight::Ticket& ticket, + const std::shared_ptr& schema, + std::shared_ptr* retrieved_data) { + std::unique_ptr read_client; + RETURN_NOT_OK( + arrow::flight::FlightClient::Connect(location.host, location.port, &read_client)); + + std::unique_ptr stream; + RETURN_NOT_OK(read_client->DoGet(ticket, schema, &stream)); + + return ReadToTable(stream, retrieved_data); +} int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing client for Flight."); @@ -49,6 +102,25 @@ int main(int argc, char** argv) { arrow::flight::FlightDescriptor descr{ arrow::flight::FlightDescriptor::PATH, "", {FLAGS_path}}; + + // 1. Put the data to the server. 
+ std::unique_ptr reader; + std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl; + std::shared_ptr in_file; + ABORT_NOT_OK(arrow::io::ReadableFile::Open(FLAGS_path, &in_file)); + ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(arrow::default_memory_pool(), + in_file, &reader)); + + std::shared_ptr original_data; + ABORT_NOT_OK(ReadToTable(reader, &original_data)); + + std::unique_ptr write_stream; + ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); + std::unique_ptr table_reader( + new arrow::TableBatchReader(*original_data)); + ABORT_NOT_OK(CopyReaderToWriter(table_reader, *write_stream)); + + // 2. Get the ticket for the data. std::unique_ptr info; ABORT_NOT_OK(client->GetFlightInfo(descr, &info)); @@ -60,23 +132,27 @@ int main(int argc, char** argv) { return -1; } - arrow::flight::Ticket ticket = info->endpoints()[0].ticket; - std::unique_ptr stream; - ABORT_NOT_OK(client->DoGet(ticket, schema, &stream)); - - std::shared_ptr out_file; - ABORT_NOT_OK(arrow::io::FileOutputStream::Open(FLAGS_output, &out_file)); - std::shared_ptr writer; - ABORT_NOT_OK(arrow::ipc::RecordBatchFileWriter::Open(out_file.get(), schema, &writer)); - - std::shared_ptr chunk; - while (true) { - ABORT_NOT_OK(stream->ReadNext(&chunk)); - if (chunk == nullptr) break; - ABORT_NOT_OK(writer->WriteRecordBatch(*chunk)); + for (const arrow::flight::FlightEndpoint& endpoint : info->endpoints()) { + const auto& ticket = endpoint.ticket; + + auto locations = endpoint.locations; + if (locations.size() == 0) { + locations = {arrow::flight::Location{FLAGS_host, FLAGS_port}}; + } + + for (const auto location : locations) { + std::cout << "Verifying location " << location.host << ':' << location.port + << std::endl; + // 3. Download the data from the server. + std::shared_ptr retrieved_data; + ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, schema, &retrieved_data)); + + // 4. Validate that the data is equal. + if (!original_data->Equals(*retrieved_data)) { + std::cerr << "Data does not match!" 
<< std::endl; + return 1; + } + } } - - ABORT_NOT_OK(writer->Close()); - return 0; } diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc index 80813e7f19a4c..7e201a031943d 100644 --- a/cpp/src/arrow/flight/test-integration-server.cc +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -27,6 +27,7 @@ #include "arrow/io/test-common.h" #include "arrow/ipc/json.h" #include "arrow/record_batch.h" +#include "arrow/table.h" #include "arrow/flight/server.h" #include "arrow/flight/test-util.h" @@ -36,57 +37,7 @@ DEFINE_int32(port, 31337, "Server port to listen on"); namespace arrow { namespace flight { -class JsonReaderRecordBatchStream : public FlightDataStream { - public: - explicit JsonReaderRecordBatchStream( - std::unique_ptr&& reader) - : index_(0), pool_(default_memory_pool()), reader_(std::move(reader)) {} - - std::shared_ptr schema() override { return reader_->schema(); } - - Status Next(ipc::internal::IpcPayload* payload) override { - if (index_ >= reader_->num_record_batches()) { - // Signal that iteration is over - payload->metadata = nullptr; - return Status::OK(); - } - - std::shared_ptr batch; - RETURN_NOT_OK(reader_->ReadRecordBatch(index_, &batch)); - index_++; - - if (!batch) { - // Signal that iteration is over - payload->metadata = nullptr; - return Status::OK(); - } else { - return ipc::internal::GetRecordBatchPayload(*batch, pool_, payload); - } - } - - private: - int index_; - MemoryPool* pool_; - std::unique_ptr reader_; -}; - class FlightIntegrationTestServer : public FlightServerBase { - Status ReadJson(const std::string& json_path, - std::unique_ptr* out) { - std::shared_ptr in_file; - std::cout << "Opening JSON file '" << json_path << "'" << std::endl; - RETURN_NOT_OK(io::ReadableFile::Open(json_path, &in_file)); - - int64_t file_size = 0; - RETURN_NOT_OK(in_file->GetSize(&file_size)); - - std::shared_ptr json_buffer; - RETURN_NOT_OK(in_file->Read(file_size, &json_buffer)); - - RETURN_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(json_buffer, out)); - return Status::OK(); - } - Status GetFlightInfo(const FlightDescriptor& request, std::unique_ptr* info) override { if (request.type == FlightDescriptor::PATH) { @@ -94,16 +45,19 @@ class FlightIntegrationTestServer : public FlightServerBase { return Status::Invalid("Invalid path"); } - std::unique_ptr reader; - RETURN_NOT_OK(ReadJson(request.path.back(), &reader)); + auto data = uploaded_chunks.find(request.path[0]); + if (data == uploaded_chunks.end()) { + return Status::KeyError("Could not find flight.", request.path[0]); + } + auto flight = data->second; - FlightEndpoint endpoint1({{request.path.back()}, {}}); + FlightEndpoint endpoint1({{request.path[0]}, {}}); FlightInfo::Data flight_data; - RETURN_NOT_OK(internal::SchemaToString(*reader->schema(), &flight_data.schema)); + RETURN_NOT_OK(internal::SchemaToString(*flight->schema(), &flight_data.schema)); flight_data.descriptor = request; flight_data.endpoints = {endpoint1}; - flight_data.total_records = reader->num_record_batches(); + flight_data.total_records = flight->num_rows(); flight_data.total_bytes = -1; FlightInfo value(flight_data); @@ -116,14 +70,44 @@ class FlightIntegrationTestServer : public FlightServerBase { Status DoGet(const Ticket& request, std::unique_ptr* data_stream) override { - std::unique_ptr reader; - RETURN_NOT_OK(ReadJson(request.ticket, &reader)); + auto data = uploaded_chunks.find(request.ticket); + if (data == uploaded_chunks.end()) { + return Status::KeyError("Could not 
find flight.", request.ticket); + } + auto flight = data->second; - *data_stream = std::unique_ptr( - new JsonReaderRecordBatchStream(std::move(reader))); + *data_stream = std::unique_ptr(new RecordBatchStream( + std::shared_ptr(new TableBatchReader(*flight)))); return Status::OK(); } + + Status DoPut(std::unique_ptr reader) override { + const FlightDescriptor& descriptor = reader->descriptor(); + + if (descriptor.type != FlightDescriptor::DescriptorType::PATH) { + return Status::Invalid("Must specify a path"); + } else if (descriptor.path.size() < 1) { + return Status::Invalid("Must specify a path"); + } + + std::string key = descriptor.path[0]; + + std::vector> retrieved_chunks; + std::shared_ptr chunk; + while (true) { + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + retrieved_chunks.push_back(chunk); + } + std::shared_ptr retrieved_data; + RETURN_NOT_OK(arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + &retrieved_data)); + uploaded_chunks[key] = retrieved_data; + return Status::OK(); + } + + std::unordered_map> uploaded_chunks; }; } // namespace flight diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 0362105bbc592..e4251bdd5d21b 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -152,12 +152,6 @@ class FlightInfo { mutable bool reconstructed_schema_; }; -// TODO(wesm): NYI -class ARROW_EXPORT FlightPutWriter { - public: - virtual ~FlightPutWriter() = default; -}; - /// \brief An iterator to FlightInfo instances returned by ListFlights class ARROW_EXPORT FlightListing { public: diff --git a/cpp/src/arrow/ipc/json.cc b/cpp/src/arrow/ipc/json.cc index 61c242ca2dbbb..56fe31cbf5691 100644 --- a/cpp/src/arrow/ipc/json.cc +++ b/cpp/src/arrow/ipc/json.cc @@ -22,6 +22,7 @@ #include #include "arrow/buffer.h" +#include "arrow/io/file.h" #include "arrow/ipc/json-internal.h" #include "arrow/memory_pool.h" #include "arrow/record_batch.h" @@ -157,6 +158,18 @@ Status JsonReader::Open(MemoryPool* pool, const std::shared_ptr& data, return (*reader)->impl_->ParseAndReadSchema(); } +Status JsonReader::Open(MemoryPool* pool, + const std::shared_ptr& in_file, + std::unique_ptr* reader) { + int64_t file_size = 0; + RETURN_NOT_OK(in_file->GetSize(&file_size)); + + std::shared_ptr json_buffer; + RETURN_NOT_OK(in_file->Read(file_size, &json_buffer)); + + return Open(pool, json_buffer, reader); +} + std::shared_ptr JsonReader::schema() const { return impl_->schema(); } int JsonReader::num_record_batches() const { return impl_->num_record_batches(); } diff --git a/cpp/src/arrow/ipc/json.h b/cpp/src/arrow/ipc/json.h index 5c00555de8ec0..aeed7070fe96e 100644 --- a/cpp/src/arrow/ipc/json.h +++ b/cpp/src/arrow/ipc/json.h @@ -33,6 +33,10 @@ class MemoryPool; class RecordBatch; class Schema; +namespace io { +class ReadableFile; +} // namespace io + namespace ipc { namespace internal { namespace json { @@ -95,6 +99,15 @@ class ARROW_EXPORT JsonReader { static Status Open(const std::shared_ptr& data, std::unique_ptr* reader); + /// \brief Create a new JSON reader from a file + /// + /// \param[in] pool a MemoryPool to use for buffer allocations + /// \param[in] in_file a ReadableFile containing JSON data + /// \param[out] reader the returned reader object + /// \return Status + static Status Open(MemoryPool* pool, const std::shared_ptr& in_file, + std::unique_ptr* reader); + /// \brief Return the schema read from the JSON std::shared_ptr schema() const; diff --git a/integration/integration_test.py 
index 0bced26f15acd..e7e8edda6ddf1 100644
--- a/integration/integration_test.py
+++ b/integration/integration_test.py
@@ -1004,28 +1004,15 @@ def _compare_flight_implementations(self, producer, consumer):
         )
         print('##########################################################')
 
-        for json_path in self.json_files:
-            print('==========================================================')
-            print('Testing file {0}'.format(json_path))
-            print('==========================================================')
-
-            name = os.path.splitext(os.path.basename(json_path))[0]
+        with producer.flight_server():
+            for json_path in self.json_files:
+                print('=' * 58)
+                print('Testing file {0}'.format(json_path))
+                print('=' * 58)
-            file_id = guid()[:8]
-
-            with producer.flight_server():
-                # Have the client request the file
-                consumer_file_path = os.path.join(
-                    self.temp_dir,
-                    file_id + '_' + name + '.consumer_requested_file')
-                consumer.flight_request(producer.FLIGHT_PORT,
-                                        json_path, consumer_file_path)
-
-                # Validate the file
-                print('-- Validating file')
-                consumer.validate(json_path, consumer_file_path)
-
-                # TODO: also have the client upload the file
+                # Have the client upload the file, then download and
+                # compare
+                consumer.flight_request(producer.FLIGHT_PORT, json_path)
 
 
 class Tester(object):
@@ -1053,7 +1040,7 @@ def validate(self, json_path, arrow_path):
     def flight_server(self):
         raise NotImplementedError
 
-    def flight_request(self, port, json_path, arrow_path):
+    def flight_request(self, port, json_path):
         raise NotImplementedError
 
@@ -1122,12 +1109,11 @@ def file_to_stream(self, file_path, stream_path):
            print(' '.join(cmd))
        run_cmd(cmd)
 
-    def flight_request(self, port, json_path, arrow_path):
+    def flight_request(self, port, json_path):
        cmd = ['java', '-cp', self.ARROW_FLIGHT_JAR,
               self.ARROW_FLIGHT_CLIENT,
               '-port', str(port),
-              '-j', json_path,
-              '-a', arrow_path]
+              '-j', json_path]
        if self.debug:
            print(' '.join(cmd))
        run_cmd(cmd)
@@ -1230,15 +1216,14 @@ def flight_server(self):
            server.terminate()
            server.wait(5)
 
-    def flight_request(self, port, json_path, arrow_path):
+    def flight_request(self, port, json_path):
        cmd = self.FLIGHT_CLIENT_CMD + [
            '-port=' + str(port),
            '-path=' + json_path,
-           '-output=' + arrow_path
        ]
        if self.debug:
            print(' '.join(cmd))
-       subprocess.run(cmd)
+       run_cmd(cmd)
 
 
 class JSTester(Tester):
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index ad7c7e28da242..bd126b5ea203c 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -127,7 +127,9 @@ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRo
     // send the schema to start.
     ArrowMessage message = new ArrowMessage(descriptor.toProtocol(), root.getSchema());
     observer.onNext(message);
-    return new PutObserver(new VectorUnloader(root, true, false), observer, resultObserver.getFuture());
+    return new PutObserver(new VectorUnloader(
+        root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */),
+        observer, resultObserver.getFuture());
   }
 
   public FlightInfo getInfo(FlightDescriptor descriptor) {
@@ -211,7 +213,8 @@ public PutObserver(VectorUnloader unloader, ClientCallStreamObserver listener) {
     listener.onNext(new ActionType("get", "pull a stream. Action must be done via standard get mechanism"));
     listener.onNext(new ActionType("put", "push a stream. Action must be done via standard get mechanism"));
     listener.onNext(new ActionType("drop", "delete a flight. Action body is a JSON encoded path."));
+    listener.onCompleted();
   }
 
   @Override
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
index 803a56c6c1afe..ed450074a767a 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
@@ -18,8 +18,8 @@
 package org.apache.arrow.flight.example.integration;
 
 import java.io.File;
-import java.io.FileOutputStream;
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 
 import org.apache.arrow.flight.FlightClient;
@@ -30,9 +30,11 @@ import org.apache.arrow.flight.Location;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorLoader;
 import org.apache.arrow.vector.VectorSchemaRoot;
-import org.apache.arrow.vector.dictionary.DictionaryProvider;
-import org.apache.arrow.vector.ipc.ArrowFileWriter;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ipc.JsonFileReader;
+import org.apache.arrow.vector.util.Validator;
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.CommandLineParser;
 import org.apache.commons.cli.DefaultParser;
@@ -48,10 +50,9 @@ class IntegrationTestClient {
 
   private IntegrationTestClient() {
     options = new Options();
-    options.addOption("a", "arrow", true, "arrow file");
     options.addOption("j", "json", true, "json file");
     options.addOption("host", true, "The host to connect to.");
-    options.addOption("port", true, "The port to connect to." );
+    options.addOption("port", true, "The port to connect to.");
   }
 
   public static void main(String[] args) {
@@ -64,7 +65,7 @@ public static void main(String[] args) {
     }
   }
 
-  static void fatalError(String message, Throwable e) {
+  private static void fatalError(String message, Throwable e) {
     System.err.println(message);
     System.err.println(e.getMessage());
     LOGGER.error(message, e);
@@ -72,36 +73,65 @@ static void fatalError(String message, Throwable e) {
   }
 
   private void run(String[] args) throws ParseException, IOException {
-    CommandLineParser parser = new DefaultParser();
-    CommandLine cmd = parser.parse(options, args, false);
-
-    String fileName = cmd.getOptionValue("arrow");
-    if (fileName == null) {
-      throw new IllegalArgumentException("missing arrow file parameter");
-    }
-    File arrowFile = new File(fileName);
-    if (arrowFile.exists()) {
-      throw new IllegalArgumentException("arrow file already exists: " + arrowFile.getAbsolutePath());
-    }
+    final CommandLineParser parser = new DefaultParser();
+    final CommandLine cmd = parser.parse(options, args, false);
 
     final String host = cmd.getOptionValue("host", "localhost");
     final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
 
     final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
-    FlightClient client = new FlightClient(allocator, new Location(host, port));
-    FlightInfo info = client.getInfo(FlightDescriptor.path(cmd.getOptionValue("json")));
+    final FlightClient client = new FlightClient(allocator, new Location(host, port));
+
+    final String inputPath = cmd.getOptionValue("j");
+
+    // 1. Read data from JSON and upload to server.
+    FlightDescriptor descriptor = FlightDescriptor.path(inputPath);
+    VectorSchemaRoot jsonRoot;
+    try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator);
+        VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) {
+      jsonRoot = VectorSchemaRoot.create(root.getSchema(), allocator);
+      VectorUnloader unloader = new VectorUnloader(root);
+      VectorLoader jsonLoader = new VectorLoader(jsonRoot);
+      FlightClient.ClientStreamListener stream = client.startPut(descriptor, root);
+      while (reader.read(root)) {
+        stream.putNext();
+        jsonLoader.load(unloader.getRecordBatch());
+        root.clear();
+      }
+      stream.completed();
+      // Need to call this, or exceptions from the server get swallowed
+      stream.getResult();
+    }
+
+    // 2. Get the ticket for the data.
+    FlightInfo info = client.getInfo(descriptor);
     List<FlightEndpoint> endpoints = info.getEndpoints();
     if (endpoints.isEmpty()) {
       throw new RuntimeException("No endpoints returned from Flight server.");
     }
 
-    FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket());
-    try (VectorSchemaRoot root = stream.getRoot();
-        FileOutputStream fileOutputStream = new FileOutputStream(arrowFile);
-        ArrowFileWriter arrowWriter = new ArrowFileWriter(root, new DictionaryProvider.MapDictionaryProvider(),
-            fileOutputStream.getChannel())) {
-      while (stream.next()) {
-        arrowWriter.writeBatch();
+    for (FlightEndpoint endpoint : info.getEndpoints()) {
+      // 3. Download the data from the server.
+      List<Location> locations = endpoint.getLocations();
+      if (locations.size() == 0) {
+        locations = Collections.singletonList(new Location(host, port));
+      }
+      for (Location location : locations) {
+        System.out.println("Verifying location " + location.getHost() + ":" + location.getPort());
+        FlightClient readClient = new FlightClient(allocator, location);
+        FlightStream stream = readClient.getStream(endpoint.getTicket());
+        VectorSchemaRoot downloadedRoot;
+        try (VectorSchemaRoot root = stream.getRoot()) {
+          downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator);
+          VectorLoader loader = new VectorLoader(downloadedRoot);
+          VectorUnloader unloader = new VectorUnloader(root);
+          while (stream.next()) {
+            loader.load(unloader.getRecordBatch());
+          }
+        }
+
+        // 4. Validate the data.
+        Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot);
       }
     }
   }
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
index 7b45e53a149be..eff2f5d4126cc 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
@@ -17,30 +17,11 @@
 package org.apache.arrow.flight.example.integration;
 
-import java.io.File;
-import java.nio.charset.StandardCharsets;
-import java.util.Collections;
-import java.util.concurrent.Callable;
-
-import org.apache.arrow.flight.Action;
-import org.apache.arrow.flight.ActionType;
-import org.apache.arrow.flight.Criteria;
-import org.apache.arrow.flight.FlightDescriptor;
-import org.apache.arrow.flight.FlightEndpoint;
-import org.apache.arrow.flight.FlightInfo;
-import org.apache.arrow.flight.FlightProducer;
-import org.apache.arrow.flight.FlightServer;
-import org.apache.arrow.flight.FlightStream;
 import org.apache.arrow.flight.Location;
-import org.apache.arrow.flight.Result;
-import org.apache.arrow.flight.Ticket;
-import org.apache.arrow.flight.auth.ServerAuthHandler;
-import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.flight.example.ExampleFlightServer;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
-import org.apache.arrow.vector.VectorSchemaRoot;
-import org.apache.arrow.vector.ipc.JsonFileReader;
-import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.util.AutoCloseables;
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.CommandLineParser;
 import org.apache.commons.cli.DefaultParser;
@@ -48,6 +29,7 @@ import org.apache.commons.cli.ParseException;
 
 class IntegrationTestServer {
+  private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestServer.class);
   private final Options options;
 
   private IntegrationTestServer() {
@@ -58,17 +40,25 @@ private IntegrationTestServer() {
   private void run(String[] args) throws Exception {
     CommandLineParser parser = new DefaultParser();
     CommandLine cmd = parser.parse(options, args, false);
+    final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
 
     final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
-    final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
-    try (final IntegrationFlightProducer producer = new IntegrationFlightProducer(allocator);
-        final FlightServer server = new FlightServer(allocator, port, producer, ServerAuthHandler.NO_OP)) {
-      server.start();
-      // Print out message for integration test script
-      System.out.println("Server listening on localhost:" + server.getPort());
-      while (true) {
-        Thread.sleep(30000);
+    final ExampleFlightServer efs = new ExampleFlightServer(allocator, new Location("localhost", port));
+    efs.start();
+    // Print out message for integration test script
+    System.out.println("Server listening on localhost:" + port);
+
+    Runtime.getRuntime().addShutdownHook(new Thread(() -> {
+      try {
+        System.out.println("\nExiting...");
+        AutoCloseables.close(efs, allocator);
+      } catch (Exception e) {
+        e.printStackTrace();
       }
+    }));
+
+    while (true) {
+      Thread.sleep(30000);
     }
   }
 
@@ -76,81 +66,17 @@ public static void main(String[] args) {
     try {
       new IntegrationTestServer().run(args);
     } catch (ParseException e) {
-      IntegrationTestClient.fatalError("Error parsing arguments", e);
+      fatalError("Error parsing arguments", e);
     } catch (Exception e) {
-      IntegrationTestClient.fatalError("Runtime error", e);
+      fatalError("Runtime error", e);
     }
   }
 
-  static class IntegrationFlightProducer implements FlightProducer, AutoCloseable {
-    private final BufferAllocator allocator;
-
-    IntegrationFlightProducer(BufferAllocator allocator) {
-      this.allocator = allocator;
-    }
-
-    @Override
-    public void close() {
-      allocator.close();
-    }
-
-    @Override
-    public void getStream(Ticket ticket, ServerStreamListener listener) {
-      String path = new String(ticket.getBytes(), StandardCharsets.UTF_8);
-      File inputFile = new File(path);
-      try (JsonFileReader reader = new JsonFileReader(inputFile, allocator)) {
-        Schema schema = reader.start();
-        try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
-          listener.start(root);
-          while (reader.read(root)) {
-            listener.putNext();
-          }
-          listener.completed();
-        }
-      } catch (Exception e) {
-        throw new RuntimeException(e);
-      }
-    }
-
-    @Override
-    public void listFlights(Criteria criteria, StreamListener listener) {
-      listener.onCompleted();
-    }
-
-    @Override
-    public FlightInfo getFlightInfo(FlightDescriptor descriptor) {
-      if (descriptor.isCommand()) {
-        throw new UnsupportedOperationException("Commands not supported.");
-      }
-      if (descriptor.getPath().size() < 1) {
-        throw new IllegalArgumentException("Must provide a path.");
-      }
-      String path = descriptor.getPath().get(0);
-      File inputFile = new File(path);
-      try (JsonFileReader reader = new JsonFileReader(inputFile, allocator)) {
-        Schema schema = reader.start();
-        return new FlightInfo(schema, descriptor,
-            Collections.singletonList(new FlightEndpoint(new Ticket(path.getBytes()),
-                new Location("localhost", 31338))),
-            0, 0);
-      } catch (Exception e) {
-        throw new RuntimeException(e);
-      }
-    }
-
-    @Override
-    public Callable acceptPut(FlightStream flightStream) {
-      return null;
-    }
-
-    @Override
-    public Result doAction(Action action) {
-      return null;
-    }
-
-    @Override
-    public void listActions(StreamListener listener) {
-      listener.onCompleted();
-    }
+  private static void fatalError(String message, Throwable e) {
+    System.err.println(message);
+    System.err.println(e.getMessage());
+    LOGGER.error(message, e);
+    System.exit(1);
   }
+
 }
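Note on the server-side pattern introduced in test-integration-server.cc above: DoPut drains each upload into an in-memory Table keyed by the descriptor path, and DoGet later serves that table back as a record batch stream. A minimal standalone sketch of that pattern follows. The class name TableServer and the member name tables_ are illustrative only (not identifiers from this patch); GetFlightInfo is omitted for brevity, and the calls used are the same ones the patch relies on (FlightMessageReader::ReadNext, Table::FromRecordBatches, RecordBatchStream over a TableBatchReader).

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "arrow/flight/server.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
#include "arrow/table.h"

namespace arrow {
namespace flight {

// Buffers DoPut uploads in memory and serves them back from DoGet.
class TableServer : public FlightServerBase {
 public:
  Status DoPut(std::unique_ptr<FlightMessageReader> reader) override {
    const FlightDescriptor& descriptor = reader->descriptor();
    if (descriptor.type != FlightDescriptor::DescriptorType::PATH ||
        descriptor.path.size() < 1) {
      return Status::Invalid("Must specify a path");
    }
    // Drain the client's stream into a vector of record batches.
    std::vector<std::shared_ptr<RecordBatch>> batches;
    std::shared_ptr<RecordBatch> batch;
    while (true) {
      RETURN_NOT_OK(reader->ReadNext(&batch));
      if (batch == nullptr) break;  // end of stream
      batches.push_back(batch);
    }
    // Consolidate the batches into a Table keyed by the descriptor path.
    std::shared_ptr<Table> table;
    RETURN_NOT_OK(Table::FromRecordBatches(reader->schema(), batches, &table));
    tables_[descriptor.path[0]] = table;
    return Status::OK();
  }

  Status DoGet(const Ticket& request,
               std::unique_ptr<FlightDataStream>* stream) override {
    auto it = tables_.find(request.ticket);
    if (it == tables_.end()) {
      return Status::KeyError("Could not find flight.", request.ticket);
    }
    // Re-expose the buffered table as a stream of record batches.
    *stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(
        std::shared_ptr<RecordBatchReader>(new TableBatchReader(*it->second))));
    return Status::OK();
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<Table>> tables_;
};

}  // namespace flight
}  // namespace arrow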
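Similarly, the JsonReader::Open(MemoryPool*, ReadableFile, ...) overload added to cpp/src/arrow/ipc/json.h can be exercised roughly as sketched below. The helper name ReadAllBatches and its path parameter are illustrative; Open, num_record_batches, and ReadRecordBatch are the calls shown in the patch.

#include <memory>
#include <string>

#include "arrow/io/file.h"
#include "arrow/ipc/json.h"
#include "arrow/memory_pool.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"

// Reads every record batch from an Arrow integration-test JSON file.
arrow::Status ReadAllBatches(const std::string& path) {
  std::shared_ptr<arrow::io::ReadableFile> in_file;
  RETURN_NOT_OK(arrow::io::ReadableFile::Open(path, &in_file));

  // The new overload slurps the file into a buffer and delegates to
  // the existing buffer-based Open.
  std::unique_ptr<arrow::ipc::internal::json::JsonReader> reader;
  RETURN_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(
      arrow::default_memory_pool(), in_file, &reader));

  for (int i = 0; i < reader->num_record_batches(); ++i) {
    std::shared_ptr<arrow::RecordBatch> batch;
    RETURN_NOT_OK(reader->ReadRecordBatch(i, &batch));
    // ... use the batch ...
  }
  return arrow::Status::OK();
}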