Skip to content

Commit

Permalink
Test all returned locations in Flight integration tests
Browse files: browse the repository at this point in the history
  • Loading branch information
David Li authored and committed on Feb 5, 2019
1 parent 905ef38 commit 3cb51ba
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 50 deletions.
79 changes: 46 additions & 33 deletions cpp/src/arrow/flight/test-integration-client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
// specific language governing permissions and limitations
// under the License.

// Client implementation for Flight integration testing. Requests the given
// path from the Flight server, which reads that file and sends it as a stream
// to the client. The client writes the server stream to the IPC file format at
// the given output file path. The integration test script then uses the
// existing integration test tools to compare the output binary with the
// original JSON
// Client implementation for Flight integration testing. Loads
// RecordBatches from the given JSON file and uploads them to the
// Flight server, which stores the data and schema in memory. The
// client then requests the data from the server and compares it to
// the data originally uploaded.

#include <iostream>
#include <memory>
Expand Down Expand Up @@ -76,11 +75,14 @@ int main(int argc, char** argv) {
}
ABORT_NOT_OK(write_stream->Close());

std::shared_ptr<arrow::Table> original_data;
ABORT_NOT_OK(
arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data));

// 2. Get the ticket for the data.
std::unique_ptr<arrow::flight::FlightInfo> info;
ABORT_NOT_OK(client->GetFlightInfo(descr, &info));

// 3. Download the data from the server.
std::shared_ptr<arrow::Schema> schema;
ABORT_NOT_OK(info->GetSchema(&schema));

Expand All @@ -89,32 +91,43 @@ int main(int argc, char** argv) {
return -1;
}

arrow::flight::Ticket ticket = info->endpoints()[0].ticket;
std::unique_ptr<arrow::RecordBatchReader> stream;
ABORT_NOT_OK(client->DoGet(ticket, schema, &stream));

std::vector<std::shared_ptr<arrow::RecordBatch>> retrieved_chunks;
std::shared_ptr<arrow::RecordBatch> chunk;
while (true) {
ABORT_NOT_OK(stream->ReadNext(&chunk));
if (chunk == nullptr) break;
retrieved_chunks.push_back(chunk);
for (const arrow::flight::FlightEndpoint& endpoint : info->endpoints()) {
const auto& ticket = endpoint.ticket;

auto locations = endpoint.locations;
if (locations.size() == 0) {
locations = {arrow::flight::Location{FLAGS_host, FLAGS_port}};
}

for (const auto location : locations) {
std::cout << "Verifying location " << location.host << ':' << location.port
<< std::endl;
// 3. Download the data from the server.
std::unique_ptr<arrow::flight::FlightClient> read_client;
ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location.host, location.port,
&read_client));

std::unique_ptr<arrow::RecordBatchReader> stream;
ABORT_NOT_OK(read_client->DoGet(ticket, schema, &stream));

std::vector<std::shared_ptr<arrow::RecordBatch>> retrieved_chunks;
std::shared_ptr<arrow::RecordBatch> chunk;
while (true) {
ABORT_NOT_OK(stream->ReadNext(&chunk));
if (chunk == nullptr) break;
retrieved_chunks.push_back(chunk);
}

// 4. Validate that the data is equal.
std::shared_ptr<arrow::Table> retrieved_data;
ABORT_NOT_OK(
arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data));

if (!original_data->Equals(*retrieved_data)) {
std::cerr << "Data does not match!" << std::endl;
return 1;
}
}
}

// 4. Validate that the data is equal.

std::shared_ptr<arrow::Table> original_data;
std::shared_ptr<arrow::Table> retrieved_data;

ABORT_NOT_OK(
arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data));
ABORT_NOT_OK(
arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data));

if (!original_data->Equals(*retrieved_data)) {
std::cerr << "Data does not match!" << std::endl;
return 1;
}

return 0;
}
10 changes: 5 additions & 5 deletions integration/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,12 +1004,12 @@ def _compare_flight_implementations(self, producer, consumer):
)
print('##########################################################')

for json_path in self.json_files:
print('==========================================================')
print('Testing file {0}'.format(json_path))
print('==========================================================')
with producer.flight_server():
for json_path in self.json_files:
print('==========================================================')
print('Testing file {0}'.format(json_path))
print('==========================================================')

with producer.flight_server():
# Have the client upload the file, then download and
# compare
consumer.flight_request(producer.FLIGHT_PORT, json_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.List;

import org.apache.arrow.flight.FlightClient;
Expand Down Expand Up @@ -51,7 +52,7 @@ private IntegrationTestClient() {
options = new Options();
options.addOption("j", "json", true, "json file");
options.addOption("host", true, "The host to connect to.");
options.addOption("port", true, "The port to connect to." );
options.addOption("port", true, "The port to connect to.");
}

public static void main(String[] args) {
Expand Down Expand Up @@ -109,18 +110,29 @@ private void run(String[] args) throws ParseException, IOException {
throw new RuntimeException("No endpoints returned from Flight server.");
}

// 3. Download the data from the server.
FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket());
VectorSchemaRoot downloadedRoot;
try (VectorSchemaRoot root = stream.getRoot()) {
downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator);
VectorLoader loader = new VectorLoader(downloadedRoot);
VectorUnloader unloader = new VectorUnloader(root);
while (stream.next()) {
loader.load(unloader.getRecordBatch());
for (FlightEndpoint endpoint : info.getEndpoints()) {
// 3. Download the data from the server.
List<Location> locations = endpoint.getLocations();
if (locations.size() == 0) {
locations = Collections.singletonList(new Location(host, port));
}
}
for (Location location : locations) {
System.out.println("Verifying location " + location.getHost() + ":" + location.getPort());
FlightClient readClient = new FlightClient(allocator, location);
FlightStream stream = readClient.getStream(endpoint.getTicket());
VectorSchemaRoot downloadedRoot;
try (VectorSchemaRoot root = stream.getRoot()) {
downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator);
VectorLoader loader = new VectorLoader(downloadedRoot);
VectorUnloader unloader = new VectorUnloader(root);
while (stream.next()) {
loader.load(unloader.getRecordBatch());
}
}

Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot);
// 4. Validate the data.
Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot);
}
}
}
}

0 comments on commit 3cb51ba

Please sign in to comment.