Skip to content

Commit

Permalink
Introduce FlightPutWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
David Li authored and David Li committed Feb 5, 2019
1 parent 58d6936 commit 6edf2e2
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 44 deletions.
41 changes: 27 additions & 14 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,13 @@ class FlightClient;

/// \brief A RecordBatchWriter implementation that writes to a Flight
/// DoPut stream.
class FlightStreamWriter : public ipc::RecordBatchWriter {
class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter {
public:
explicit FlightStreamWriter(std::unique_ptr<ClientRpc> rpc,
const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
MemoryPool* pool = default_memory_pool())
: rpc_(std::move(rpc)),
descriptor_(descriptor),
schema_(schema),
pool_(pool) {}
explicit FlightPutWriterImpl(std::unique_ptr<ClientRpc> rpc,
const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
MemoryPool* pool = default_memory_pool())
: rpc_(std::move(rpc)), descriptor_(descriptor), schema_(schema), pool_(pool) {}

Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override {
IpcPayload payload;
Expand Down Expand Up @@ -171,6 +168,22 @@ class FlightStreamWriter : public ipc::RecordBatchWriter {
friend class FlightClient;
};

FlightPutWriter::~FlightPutWriter() {}

FlightPutWriter::FlightPutWriter(std::unique_ptr<FlightPutWriterImpl> impl) {
impl_ = std::move(impl);
}

Status FlightPutWriter::WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) {
return impl_->WriteRecordBatch(batch, allow_64bit);
}

Status FlightPutWriter::Close() { return impl_->Close(); }

void FlightPutWriter::set_memory_pool(MemoryPool* pool) {
return impl_->set_memory_pool(pool);
}

class FlightClient::FlightClientImpl {
public:
Status Connect(const std::string& host, int port) {
Expand Down Expand Up @@ -277,10 +290,10 @@ class FlightClient::FlightClientImpl {
}

Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
std::unique_ptr<ipc::RecordBatchWriter>* stream) {
std::unique_ptr<FlightPutWriter>* stream) {
std::unique_ptr<ClientRpc> rpc(new ClientRpc);
std::unique_ptr<FlightStreamWriter> out(
new FlightStreamWriter(std::move(rpc), descriptor, schema));
std::unique_ptr<FlightPutWriter::FlightPutWriterImpl> out(
new FlightPutWriter::FlightPutWriterImpl(std::move(rpc), descriptor, schema));
std::unique_ptr<grpc::ClientWriter<pb::FlightData>> write_stream(
stub_->DoPut(&out->rpc_->context, &out->response));

Expand All @@ -302,7 +315,7 @@ class FlightClient::FlightClientImpl {
}

out->set_stream(std::move(write_stream));
*stream = std::move(out);
*stream = std::unique_ptr<FlightPutWriter>(new FlightPutWriter(std::move(out)));
return Status::OK();
}

Expand Down Expand Up @@ -350,7 +363,7 @@ Status FlightClient::DoGet(const Ticket& ticket, const std::shared_ptr<Schema>&

Status FlightClient::DoPut(const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
std::unique_ptr<ipc::RecordBatchWriter>* stream) {
std::unique_ptr<FlightPutWriter>* stream) {
return impl_->DoPut(descriptor, schema, stream);
}

Expand Down
28 changes: 21 additions & 7 deletions cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <string>
#include <vector>

#include "arrow/ipc/writer.h"
#include "arrow/status.h"
#include "arrow/util/visibility.h"

Expand All @@ -35,14 +36,10 @@ class RecordBatch;
class RecordBatchReader;
class Schema;

namespace ipc {

class RecordBatchWriter;

} // namespace ipc

namespace flight {

class FlightPutWriter;

/// \brief Client class for Arrow Flight RPC services (gRPC-based).
/// API experimental for now
class ARROW_EXPORT FlightClient {
Expand Down Expand Up @@ -108,13 +105,30 @@ class ARROW_EXPORT FlightClient {
/// \param[out] stream a writer to write record batches to
/// \return Status
Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
std::unique_ptr<ipc::RecordBatchWriter>* stream);
std::unique_ptr<FlightPutWriter>* stream);

private:
FlightClient();
class FlightClientImpl;
std::unique_ptr<FlightClientImpl> impl_;
};

/// \brief An interface to upload record batches to a Flight server
class ARROW_EXPORT FlightPutWriter : public ipc::RecordBatchWriter {
public:
~FlightPutWriter();

Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override;
Status Close() override;
void set_memory_pool(MemoryPool* pool) override;

private:
class FlightPutWriterImpl;
explicit FlightPutWriter(std::unique_ptr<FlightPutWriterImpl> impl);
std::unique_ptr<FlightPutWriterImpl> impl_;

friend class FlightClient;
};

} // namespace flight
} // namespace arrow
3 changes: 1 addition & 2 deletions cpp/src/arrow/flight/internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket) {

// FlightData

Status FromProto(const pb::FlightData& pb_data,
FlightDescriptor* descriptor,
Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor,
std::unique_ptr<ipc::Message>* message) {
RETURN_NOT_OK(internal::FromProto(pb_data.flight_descriptor(), descriptor));
const std::string& header = pb_data.data_header();
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/arrow/flight/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ Status FromProto(const pb::Result& pb_result, Result* result);
Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria);
Status FromProto(const pb::Location& pb_location, Location* location);
Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket);
Status FromProto(const pb::FlightData& pb_data,
FlightDescriptor* descriptor,
Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor,
std::unique_ptr<ipc::Message>* message);
Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr);
Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint);
Expand Down
26 changes: 12 additions & 14 deletions cpp/src/arrow/flight/serialization-internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ inline Status FailSerialization(Status status) {

inline arrow::Status FailSerialization(arrow::Status status) {
if (!status.ok()) {
ARROW_LOG(WARNING) << "Error deserializing Flight message: "
<< status.ToString();
ARROW_LOG(WARNING) << "Error deserializing Flight message: " << status.ToString();
}
return status;
}
Expand All @@ -180,9 +179,8 @@ template <>
class SerializationTraits<FlightData> {
public:
static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) {
return FailSerialization(
Status(StatusCode::UNIMPLEMENTED,
"internal::FlightData serialization not implemented"));
return FailSerialization(Status(
StatusCode::UNIMPLEMENTED, "internal::FlightData serialization not implemented"));
}

static Status Deserialize(ByteBuffer* buffer, FlightData* out) {
Expand All @@ -207,20 +205,20 @@ class SerializationTraits<FlightData> {
case pb::FlightData::kFlightDescriptorFieldNumber: {
pb::FlightDescriptor pb_descriptor;
if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) {
return FailSerialization(Status(StatusCode::INTERNAL,
"Unable to parse FlightDescriptor"));
return FailSerialization(
Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor"));
}
} break;
case pb::FlightData::kDataHeaderFieldNumber: {
if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) {
return FailSerialization(Status(StatusCode::INTERNAL,
"Unable to read FlightData metadata"));
return FailSerialization(
Status(StatusCode::INTERNAL, "Unable to read FlightData metadata"));
}
} break;
case pb::FlightData::kDataBodyFieldNumber: {
if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) {
return FailSerialization(Status(StatusCode::INTERNAL,
"Unable to read FlightData body"));
return FailSerialization(
Status(StatusCode::INTERNAL, "Unable to read FlightData body"));
}
} break;
default:
Expand Down Expand Up @@ -278,9 +276,9 @@ class SerializationTraits<IpcPayload> {

// TODO(wesm): messages over 2GB unlikely to be yet supported
if (total_size > kInt32Max) {
return FailSerialization(grpc::Status(
grpc::StatusCode::INVALID_ARGUMENT,
"Cannot send record batches exceeding 2GB yet"));
return FailSerialization(
grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"Cannot send record batches exceeding 2GB yet"));
}

// Allocate slice, assign to output buffer
Expand Down
10 changes: 5 additions & 5 deletions cpp/src/arrow/flight/test-integration-client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ arrow::Status ReadToTable(std::unique_ptr<arrow::ipc::internal::json::JsonReader

/// \brief Helper to copy a RecordBatchReader to a RecordBatchWriter.
arrow::Status CopyReaderToWriter(std::unique_ptr<arrow::RecordBatchReader>& reader,
std::unique_ptr<arrow::ipc::RecordBatchWriter>& writer) {
arrow::ipc::RecordBatchWriter& writer) {
while (true) {
std::shared_ptr<arrow::RecordBatch> chunk;
RETURN_NOT_OK(reader->ReadNext(&chunk));
if (chunk == nullptr) break;
RETURN_NOT_OK(writer->WriteRecordBatch(*chunk));
RETURN_NOT_OK(writer.WriteRecordBatch(*chunk));
}
return writer->Close();
return writer.Close();
}

/// \brief Helper to read a flight into a Table.
Expand Down Expand Up @@ -114,11 +114,11 @@ int main(int argc, char** argv) {
std::shared_ptr<arrow::Table> original_data;
ABORT_NOT_OK(ReadToTable(reader, &original_data));

std::unique_ptr<arrow::ipc::RecordBatchWriter> write_stream;
std::unique_ptr<arrow::flight::FlightPutWriter> write_stream;
ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream));
std::unique_ptr<arrow::RecordBatchReader> table_reader(
new arrow::TableBatchReader(*original_data));
ABORT_NOT_OK(CopyReaderToWriter(table_reader, write_stream));
ABORT_NOT_OK(CopyReaderToWriter(table_reader, *write_stream));

// 2. Get the ticket for the data.
std::unique_ptr<arrow::flight::FlightInfo> info;
Expand Down

0 comments on commit 6edf2e2

Please sign in to comment.