Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-4626: [Flight] Add application-defined metadata to DoGet/DoPut #4282

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 179 additions & 28 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ struct ClientRpc {
if (auth_handler) {
std::string token;
RETURN_NOT_OK(auth_handler->GetToken(&token));
context.AddMetadata(internal::AUTH_HEADER, token);
context.AddMetadata(internal::kGrpcAuthHeader, token);
}
return Status::OK();
}
Expand Down Expand Up @@ -129,29 +129,64 @@ class GrpcClientAuthReader : public ClientAuthReader {
stream_;
};

class FlightIpcMessageReader : public ipc::MessageReader {
// The next two classes are intertwined. To get the application
// metadata while avoiding reimplementing RecordBatchStreamReader, we
// create an ipc::MessageReader that is tied to the
// MetadataRecordBatchReader. Every time an IPC message is read, it updates
// the application metadata field of the MetadataRecordBatchReader. The
// MetadataRecordBatchReader wraps RecordBatchStreamReader, offering an
// additional method to get both the record batch and application
// metadata.

class GrpcIpcMessageReader;
class GrpcStreamReader : public FlightStreamReader {
public:
FlightIpcMessageReader(std::unique_ptr<ClientRpc> rpc,
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream)
: rpc_(std::move(rpc)), stream_(std::move(stream)), stream_finished_(false) {}
GrpcStreamReader();

static Status Open(std::unique_ptr<ClientRpc> rpc,
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream,
std::unique_ptr<GrpcStreamReader>* out);
std::shared_ptr<Schema> schema() const override;
Status Next(FlightStreamChunk* out) override;
void Cancel() override;

private:
friend class GrpcIpcMessageReader;
std::unique_ptr<ipc::RecordBatchReader> batch_reader_;
std::shared_ptr<Buffer> last_app_metadata_;
std::shared_ptr<ClientRpc> rpc_;
};

class GrpcIpcMessageReader : public ipc::MessageReader {
public:
GrpcIpcMessageReader(GrpcStreamReader* reader, std::shared_ptr<ClientRpc> rpc,
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream)
: flight_reader_(reader),
rpc_(rpc),
stream_(std::move(stream)),
stream_finished_(false) {}

Status ReadNextMessage(std::unique_ptr<ipc::Message>* out) override {
if (stream_finished_) {
*out = nullptr;
flight_reader_->last_app_metadata_ = nullptr;
return Status::OK();
}
internal::FlightData data;
if (!internal::ReadPayload(stream_.get(), &data)) {
// Stream is completed
stream_finished_ = true;
*out = nullptr;
flight_reader_->last_app_metadata_ = nullptr;
return OverrideWithServerError(Status::OK());
}
// Validate IPC message
auto st = data.OpenMessage(out);
if (!st.ok()) {
flight_reader_->last_app_metadata_ = nullptr;
return OverrideWithServerError(std::move(st));
}
flight_reader_->last_app_metadata_ = data.app_metadata;
return Status::OK();
}

Expand All @@ -162,23 +197,93 @@ class FlightIpcMessageReader : public ipc::MessageReader {
return std::move(st);
}

private:
GrpcStreamReader* flight_reader_;
// The RPC context lifetime must be coupled to the ClientReader
std::unique_ptr<ClientRpc> rpc_;
std::shared_ptr<ClientRpc> rpc_;
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream_;
bool stream_finished_;
};

GrpcStreamReader::GrpcStreamReader() {}

Status GrpcStreamReader::Open(std::unique_ptr<ClientRpc> rpc,
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream,
std::unique_ptr<GrpcStreamReader>* out) {
*out = std::unique_ptr<GrpcStreamReader>(new GrpcStreamReader);
out->get()->rpc_ = std::move(rpc);
std::unique_ptr<GrpcIpcMessageReader> message_reader(
new GrpcIpcMessageReader(out->get(), out->get()->rpc_, std::move(stream)));
return ipc::RecordBatchStreamReader::Open(std::move(message_reader),
&(*out)->batch_reader_);
}

std::shared_ptr<Schema> GrpcStreamReader::schema() const {
return batch_reader_->schema();
}

Status GrpcStreamReader::Next(FlightStreamChunk* out) {
out->app_metadata = nullptr;
RETURN_NOT_OK(batch_reader_->ReadNext(&out->data));
out->app_metadata = std::move(last_app_metadata_);
return Status::OK();
}

void GrpcStreamReader::Cancel() { rpc_->context.TryCancel(); }

// Similarly, the next two classes are intertwined. In order to get
// application-specific metadata to the IpcPayloadWriter,
// DoPutPayloadWriter takes a pointer to
// GrpcStreamWriter. GrpcStreamWriter updates a metadata field on
// write; DoPutPayloadWriter reads that metadata field to determine
// what to write.

class DoPutPayloadWriter;
class GrpcStreamWriter : public FlightStreamWriter {
public:
~GrpcStreamWriter() = default;

GrpcStreamWriter() : app_metadata_(nullptr), batch_writer_(nullptr) {}

static Status Open(
const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
std::unique_ptr<ClientRpc> rpc, std::unique_ptr<pb::PutResult> response,
std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> writer,
std::unique_ptr<FlightStreamWriter>* out);

Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override {
return WriteWithMetadata(batch, nullptr, allow_64bit);
}
Status WriteWithMetadata(const RecordBatch& batch, std::shared_ptr<Buffer> app_metadata,
bool allow_64bit = false) override {
app_metadata_ = app_metadata;
return batch_writer_->WriteRecordBatch(batch, allow_64bit);
}
void set_memory_pool(MemoryPool* pool) override {
batch_writer_->set_memory_pool(pool);
}
Status Close() override { return batch_writer_->Close(); }

private:
friend class DoPutPayloadWriter;
std::shared_ptr<Buffer> app_metadata_;
std::unique_ptr<ipc::RecordBatchWriter> batch_writer_;
};

/// A IpcPayloadWriter implementation that writes to a DoPut stream
class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
public:
DoPutPayloadWriter(const FlightDescriptor& descriptor, std::unique_ptr<ClientRpc> rpc,
std::unique_ptr<protocol::PutResult> response,
std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer)
DoPutPayloadWriter(
const FlightDescriptor& descriptor, std::unique_ptr<ClientRpc> rpc,
std::unique_ptr<pb::PutResult> response,
std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> writer,
GrpcStreamWriter* stream_writer)
: descriptor_(descriptor),
rpc_(std::move(rpc)),
response_(std::move(response)),
writer_(std::move(writer)),
first_payload_(true) {}
first_payload_(true),
stream_writer_(stream_writer) {}

~DoPutPayloadWriter() override = default;

Expand All @@ -201,6 +306,9 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
}
RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor));
first_payload_ = false;
} else if (ipc_payload.type == ipc::Message::RECORD_BATCH &&
stream_writer_->app_metadata_) {
payload.app_metadata = std::move(stream_writer_->app_metadata_);
}

if (!internal::WritePayload(payload, writer_.get())) {
Expand All @@ -211,6 +319,10 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {

Status Close() override {
bool finished_writes = writer_->WritesDone();
// Drain the read side to avoid hanging
pb::PutResult message;
while (writer_->Read(&message)) {
}
RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish()));
if (!finished_writes) {
return Status::UnknownError(
Expand All @@ -223,9 +335,47 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
// TODO: there isn't a way to access this as a user.
const FlightDescriptor descriptor_;
std::unique_ptr<ClientRpc> rpc_;
std::unique_ptr<protocol::PutResult> response_;
std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer_;
std::unique_ptr<pb::PutResult> response_;
std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> writer_;
bool first_payload_;
GrpcStreamWriter* stream_writer_;
};

Status GrpcStreamWriter::Open(
const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
std::unique_ptr<ClientRpc> rpc, std::unique_ptr<pb::PutResult> response,
std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> writer,
std::unique_ptr<FlightStreamWriter>* out) {
std::unique_ptr<GrpcStreamWriter> result(new GrpcStreamWriter);
std::unique_ptr<ipc::internal::IpcPayloadWriter> payload_writer(new DoPutPayloadWriter(
descriptor, std::move(rpc), std::move(response), writer, result.get()));
RETURN_NOT_OK(ipc::internal::OpenRecordBatchWriter(std::move(payload_writer), schema,
&result->batch_writer_));
*out = std::move(result);
return Status::OK();
}

FlightMetadataReader::~FlightMetadataReader() = default;

class GrpcMetadataReader : public FlightMetadataReader {
public:
explicit GrpcMetadataReader(
std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> reader)
: reader_(reader) {}

Status ReadMetadata(std::shared_ptr<Buffer>* out) override {
pb::PutResult message;
if (reader_->Read(&message)) {
*out = Buffer::FromString(std::move(*message.release_app_metadata()));
} else {
// Stream finished
*out = nullptr;
}
return Status::OK();
}

private:
std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> reader_;
};

class FlightClient::FlightClientImpl {
Expand Down Expand Up @@ -367,7 +517,7 @@ class FlightClient::FlightClientImpl {
}

Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
std::unique_ptr<RecordBatchReader>* out) {
std::unique_ptr<FlightStreamReader>* out) {
pb::Ticket pb_ticket;
internal::ToProto(ticket, &pb_ticket);

Expand All @@ -376,25 +526,25 @@ class FlightClient::FlightClientImpl {
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream(
stub_->DoGet(&rpc->context, pb_ticket));

std::unique_ptr<ipc::MessageReader> message_reader(
new FlightIpcMessageReader(std::move(rpc), std::move(stream)));
return ipc::RecordBatchStreamReader::Open(std::move(message_reader), out);
std::unique_ptr<GrpcStreamReader> reader;
RETURN_NOT_OK(GrpcStreamReader::Open(std::move(rpc), std::move(stream), &reader));
*out = std::move(reader);
return Status::OK();
}

Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
std::unique_ptr<ipc::RecordBatchWriter>* out) {
std::unique_ptr<FlightStreamWriter>* out,
std::unique_ptr<FlightMetadataReader>* reader) {
std::unique_ptr<ClientRpc> rpc(new ClientRpc(options));
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
std::unique_ptr<protocol::PutResult> response(new protocol::PutResult);
std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer(
stub_->DoPut(&rpc->context, response.get()));

std::unique_ptr<ipc::internal::IpcPayloadWriter> payload_writer(
new DoPutPayloadWriter(descriptor, std::move(rpc), std::move(response),
std::move(writer)));
std::unique_ptr<pb::PutResult> response(new pb::PutResult);
std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> writer(
stub_->DoPut(&rpc->context));

return ipc::internal::OpenRecordBatchWriter(std::move(payload_writer), schema, out);
*reader = std::unique_ptr<FlightMetadataReader>(new GrpcMetadataReader(writer));
return GrpcStreamWriter::Open(descriptor, schema, std::move(rpc), std::move(response),
writer, out);
}

private:
Expand Down Expand Up @@ -449,15 +599,16 @@ Status FlightClient::ListFlights(const FlightCallOptions& options,
}

Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket,
std::unique_ptr<RecordBatchReader>* stream) {
std::unique_ptr<FlightStreamReader>* stream) {
return impl_->DoGet(options, ticket, stream);
}

Status FlightClient::DoPut(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
std::unique_ptr<ipc::RecordBatchWriter>* stream) {
return impl_->DoPut(options, descriptor, schema, stream);
std::unique_ptr<FlightStreamWriter>* stream,
std::unique_ptr<FlightMetadataReader>* reader) {
return impl_->DoPut(options, descriptor, schema, stream, reader);
}

} // namespace flight
Expand Down
Loading