Skip to content

Commit

Permalink
Clean up C++ Flight integration client
Browse files Browse the repository at this point in the history
  • Loading branch information
David Li authored and David Li committed Feb 5, 2019
1 parent 3e185cb commit 65d6ba2
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 46 deletions.
15 changes: 8 additions & 7 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,14 @@ class FlightClient;
/// DoPut stream.
class FlightStreamWriter : public ipc::RecordBatchWriter {
public:
explicit FlightStreamWriter(std::unique_ptr<ClientRpc>&& rpc,
explicit FlightStreamWriter(std::unique_ptr<ClientRpc> rpc,
const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema)
: rpc_{std::move(rpc)},
descriptor_{descriptor},
schema_{schema},
pool_{default_memory_pool()} {}
const std::shared_ptr<Schema>& schema,
MemoryPool* pool = default_memory_pool())
: rpc_(std::move(rpc)),
descriptor_(descriptor),
schema_(schema),
pool_(pool) {}

Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override {
IpcPayload payload;
Expand Down Expand Up @@ -154,7 +155,7 @@ class FlightStreamWriter : public ipc::RecordBatchWriter {
private:
/// \brief Set the gRPC writer backing this Flight stream.
/// \param [in] writer the gRPC writer
void set_stream(std::unique_ptr<grpc::ClientWriter<pb::FlightData>>&& writer) {
void set_stream(std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer) {
writer_ = std::move(writer);
}

Expand Down
4 changes: 3 additions & 1 deletion cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ class ARROW_EXPORT FlightClient {
Status DoGet(const Ticket& ticket, const std::shared_ptr<Schema>& schema,
std::unique_ptr<RecordBatchReader>* stream);

/// \brief Upload data to a Flight described by the given descriptor.
/// \brief Upload data to a Flight described by the given
/// descriptor. The caller must call Close() on the returned stream
/// once they are done writing.
/// \param[in] descriptor the descriptor of the stream
/// \param[in] schema the schema for the data to upload
/// \param[out] stream a writer to write record batches to
Expand Down
10 changes: 5 additions & 5 deletions cpp/src/arrow/flight/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ namespace flight {
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \
}

class ARROW_EXPORT FlightMessageReaderImpl : public FlightMessageReader {
class FlightMessageReaderImpl : public FlightMessageReader {
public:
FlightMessageReaderImpl(const FlightDescriptor& descriptor,
std::shared_ptr<Schema> schema,
grpc::ServerReader<pb::FlightData>* reader)
: descriptor_{descriptor},
schema_{schema},
reader_{reader},
stream_finished_{false} {}
: descriptor_(descriptor),
schema_(schema),
reader_(reader),
stream_finished_(false) {}

const FlightDescriptor& descriptor() const override { return descriptor_; }

Expand Down
93 changes: 62 additions & 31 deletions cpp/src/arrow/flight/test-integration-client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,60 @@ DEFINE_string(host, "localhost", "Server port to connect to");
DEFINE_int32(port, 31337, "Server port to connect to");
DEFINE_string(path, "", "Resource path to request");

/// \brief Drain a RecordBatchReader and assemble everything it yields into a Table.
///
/// Batches are pulled with ReadNext() until a null batch is returned; the
/// table is then built using the schema reported by the reader.
///
/// \param[in] reader the reader to drain
/// \param[out] retrieved_data receives the table built from the batches read
/// \return the first non-OK status produced while reading, otherwise the
/// status of Table::FromRecordBatches
arrow::Status ReadToTable(std::unique_ptr<arrow::RecordBatchReader>& reader,
                          std::shared_ptr<arrow::Table>* retrieved_data) {
  std::vector<std::shared_ptr<arrow::RecordBatch>> batches;
  std::shared_ptr<arrow::RecordBatch> batch;

  // Prime the loop with the first read, then keep reading until the reader
  // hands back a null batch.
  RETURN_NOT_OK(reader->ReadNext(&batch));
  while (batch != nullptr) {
    batches.push_back(batch);
    RETURN_NOT_OK(reader->ReadNext(&batch));
  }

  return arrow::Table::FromRecordBatches(reader->schema(), batches, retrieved_data);
}

/// \brief Load every record batch of a JsonReader and assemble them into a Table.
///
/// Batches are fetched by index, from 0 up to num_record_batches(), and the
/// table is built using the schema reported by the reader.
///
/// \param[in] reader the JSON reader to read from
/// \param[out] retrieved_data receives the table built from the batches read
/// \return the first non-OK status produced while reading, otherwise the
/// status of Table::FromRecordBatches
arrow::Status ReadToTable(std::unique_ptr<arrow::ipc::internal::json::JsonReader>& reader,
                          std::shared_ptr<arrow::Table>* retrieved_data) {
  std::vector<std::shared_ptr<arrow::RecordBatch>> batches;
  std::shared_ptr<arrow::RecordBatch> batch;

  int index = 0;
  while (index < reader->num_record_batches()) {
    RETURN_NOT_OK(reader->ReadRecordBatch(index, &batch));
    batches.push_back(batch);
    ++index;
  }

  return arrow::Table::FromRecordBatches(reader->schema(), batches, retrieved_data);
}

/// \brief Forward every batch from a RecordBatchReader to a RecordBatchWriter.
///
/// Batches are read until the reader returns a null batch; each one is
/// written as it is read. Once the reader is exhausted the writer is closed.
/// If a read or write fails, that status is returned immediately and Close()
/// is not called by this function.
///
/// \param[in] reader the source of record batches
/// \param[in] writer the destination for record batches
/// \return the first non-OK status from reading or writing, otherwise the
/// status of closing the writer
arrow::Status CopyReaderToWriter(std::unique_ptr<arrow::RecordBatchReader>& reader,
                                 std::unique_ptr<arrow::ipc::RecordBatchWriter>& writer) {
  bool exhausted = false;
  while (!exhausted) {
    std::shared_ptr<arrow::RecordBatch> batch;
    RETURN_NOT_OK(reader->ReadNext(&batch));
    if (batch != nullptr) {
      RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
    } else {
      // A null batch marks the end of the stream.
      exhausted = true;
    }
  }
  return writer->Close();
}

/// \brief Fetch a flight from a given location and materialize it as a Table.
///
/// Connects a new FlightClient to the location's host and port, issues a
/// DoGet for the ticket, and reads the resulting stream into a table.
///
/// \param[in] location the host/port to connect to
/// \param[in] ticket the ticket identifying the data to download
/// \param[in] schema the schema passed to DoGet for the stream
/// \param[out] retrieved_data receives the table built from the stream
/// \return the first non-OK status from connecting, DoGet, or reading
arrow::Status ConsumeFlightLocation(const arrow::flight::Location& location,
                                    const arrow::flight::Ticket& ticket,
                                    const std::shared_ptr<arrow::Schema>& schema,
                                    std::shared_ptr<arrow::Table>* retrieved_data) {
  // Open a dedicated connection to the endpoint that serves this flight.
  std::unique_ptr<arrow::flight::FlightClient> endpoint_client;
  RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location.host, location.port,
                                                     &endpoint_client));

  // Request the data stream for the ticket.
  std::unique_ptr<arrow::RecordBatchReader> batch_stream;
  RETURN_NOT_OK(endpoint_client->DoGet(ticket, schema, &batch_stream));

  // Collect the whole stream into a single table for the caller.
  return ReadToTable(batch_stream, retrieved_data);
}

int main(int argc, char** argv) {
gflags::SetUsageMessage("Integration testing client for Flight.");
gflags::ParseCommandLineFlags(&argc, &argv, true);
Expand All @@ -57,21 +111,14 @@ int main(int argc, char** argv) {
ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(arrow::default_memory_pool(),
in_file, &reader));

std::shared_ptr<arrow::Table> original_data;
ABORT_NOT_OK(ReadToTable(reader, &original_data));

std::unique_ptr<arrow::ipc::RecordBatchWriter> write_stream;
ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream));

std::vector<std::shared_ptr<arrow::RecordBatch>> original_chunks;
for (int i = 0; i < reader->num_record_batches(); i++) {
std::shared_ptr<arrow::RecordBatch> batch;
ABORT_NOT_OK(reader->ReadRecordBatch(i, &batch));
original_chunks.push_back(batch);
ABORT_NOT_OK(write_stream->WriteRecordBatch(*batch));
}
ABORT_NOT_OK(write_stream->Close());

std::shared_ptr<arrow::Table> original_data;
ABORT_NOT_OK(
arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data));
std::unique_ptr<arrow::RecordBatchReader> table_reader(
new arrow::TableBatchReader(*original_data));
ABORT_NOT_OK(CopyReaderToWriter(table_reader, write_stream));

// 2. Get the ticket for the data.
std::unique_ptr<arrow::flight::FlightInfo> info;
Expand All @@ -97,26 +144,10 @@ int main(int argc, char** argv) {
std::cout << "Verifying location " << location.host << ':' << location.port
<< std::endl;
// 3. Download the data from the server.
std::unique_ptr<arrow::flight::FlightClient> read_client;
ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location.host, location.port,
&read_client));

std::unique_ptr<arrow::RecordBatchReader> stream;
ABORT_NOT_OK(read_client->DoGet(ticket, schema, &stream));

std::vector<std::shared_ptr<arrow::RecordBatch>> retrieved_chunks;
std::shared_ptr<arrow::RecordBatch> chunk;
while (true) {
ABORT_NOT_OK(stream->ReadNext(&chunk));
if (chunk == nullptr) break;
retrieved_chunks.push_back(chunk);
}

// 4. Validate that the data is equal.
std::shared_ptr<arrow::Table> retrieved_data;
ABORT_NOT_OK(
arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data));
ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, schema, &retrieved_data));

// 4. Validate that the data is equal.
if (!original_data->Equals(*retrieved_data)) {
std::cerr << "Data does not match!" << std::endl;
return 1;
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/flight/test-integration-server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class FlightIntegrationTestServer : public FlightServerBase {

auto data = uploaded_chunks.find(request.path[0]);
if (data == uploaded_chunks.end()) {
return Status::KeyError("Could not find flight.");
return Status::KeyError("Could not find flight.", request.path[0]);
}
auto flight = data->second;

Expand All @@ -72,7 +72,7 @@ class FlightIntegrationTestServer : public FlightServerBase {
std::unique_ptr<FlightDataStream>* data_stream) override {
auto data = uploaded_chunks.find(request.ticket);
if (data == uploaded_chunks.end()) {
return Status::KeyError("Could not find flight.");
return Status::KeyError("Could not find flight.", request.ticket);
}
auto flight = data->second;

Expand Down

0 comments on commit 65d6ba2

Please sign in to comment.