From 3cb51badd430a634dee32f3b73026f5d72102604 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 29 Jan 2019 13:44:51 -0500 Subject: [PATCH] Test all returned locations in Flight integration tests --- .../arrow/flight/test-integration-client.cc | 79 +++++++++++-------- integration/integration_test.py | 10 +-- .../integration/IntegrationTestClient.java | 36 ++++++--- 3 files changed, 75 insertions(+), 50 deletions(-) diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index a94001ff33713..64f9dc3cd158e 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -// Client implementation for Flight integration testing. Requests the given -// path from the Flight server, which reads that file and sends it as a stream -// to the client. The client writes the server stream to the IPC file format at -// the given output file path. The integration test script then uses the -// existing integration test tools to compare the output binary with the -// original JSON +// Client implementation for Flight integration testing. Loads +// RecordBatches from the given JSON file and uploads them to the +// Flight server, which stores the data and schema in memory. The +// client then requests the data from the server and compares it to +// the data originally uploaded. #include #include @@ -76,11 +75,14 @@ int main(int argc, char** argv) { } ABORT_NOT_OK(write_stream->Close()); + std::shared_ptr original_data; + ABORT_NOT_OK( + arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data)); + // 2. Get the ticket for the data. std::unique_ptr info; ABORT_NOT_OK(client->GetFlightInfo(descr, &info)); - // 3. Download the data from the server. std::shared_ptr schema; ABORT_NOT_OK(info->GetSchema(&schema)); @@ -89,32 +91,43 @@ int main(int argc, char** argv) { return -1; } - arrow::flight::Ticket ticket = info->endpoints()[0].ticket; - std::unique_ptr stream; - ABORT_NOT_OK(client->DoGet(ticket, schema, &stream)); - - std::vector> retrieved_chunks; - std::shared_ptr chunk; - while (true) { - ABORT_NOT_OK(stream->ReadNext(&chunk)); - if (chunk == nullptr) break; - retrieved_chunks.push_back(chunk); + for (const arrow::flight::FlightEndpoint& endpoint : info->endpoints()) { + const auto& ticket = endpoint.ticket; + + auto locations = endpoint.locations; + if (locations.size() == 0) { + locations = {arrow::flight::Location{FLAGS_host, FLAGS_port}}; + } + + for (const auto location : locations) { + std::cout << "Verifying location " << location.host << ':' << location.port + << std::endl; + // 3. Download the data from the server. + std::unique_ptr read_client; + ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location.host, location.port, + &read_client)); + + std::unique_ptr stream; + ABORT_NOT_OK(read_client->DoGet(ticket, schema, &stream)); + + std::vector> retrieved_chunks; + std::shared_ptr chunk; + while (true) { + ABORT_NOT_OK(stream->ReadNext(&chunk)); + if (chunk == nullptr) break; + retrieved_chunks.push_back(chunk); + } + + // 4. Validate that the data is equal. + std::shared_ptr retrieved_data; + ABORT_NOT_OK( + arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data)); + + if (!original_data->Equals(*retrieved_data)) { + std::cerr << "Data does not match!" << std::endl; + return 1; + } + } } - - // 4. Validate that the data is equal. - - std::shared_ptr original_data; - std::shared_ptr retrieved_data; - - ABORT_NOT_OK( - arrow::Table::FromRecordBatches(reader->schema(), original_chunks, &original_data)); - ABORT_NOT_OK( - arrow::Table::FromRecordBatches(schema, retrieved_chunks, &retrieved_data)); - - if (!original_data->Equals(*retrieved_data)) { - std::cerr << "Data does not match!" << std::endl; - return 1; - } - return 0; } diff --git a/integration/integration_test.py b/integration/integration_test.py index cef4e5697b29e..fc02d0712006c 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -1004,12 +1004,12 @@ def _compare_flight_implementations(self, producer, consumer): ) print('##########################################################') - for json_path in self.json_files: - print('==========================================================') - print('Testing file {0}'.format(json_path)) - print('==========================================================') + with producer.flight_server(): + for json_path in self.json_files: + print('==========================================================') + print('Testing file {0}'.format(json_path)) + print('==========================================================') - with producer.flight_server(): # Have the client upload the file, then download and # compare consumer.flight_request(producer.FLIGHT_PORT, json_path) diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java index 3c04a19c9de8d..ed450074a767a 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.util.Collections; import java.util.List; import org.apache.arrow.flight.FlightClient; @@ -51,7 +52,7 @@ private IntegrationTestClient() { options = new Options(); options.addOption("j", "json", true, "json file"); options.addOption("host", true, "The host to connect to."); - options.addOption("port", true, "The port to connect to." ); + options.addOption("port", true, "The port to connect to."); } public static void main(String[] args) { @@ -109,18 +110,29 @@ private void run(String[] args) throws ParseException, IOException { throw new RuntimeException("No endpoints returned from Flight server."); } - // 3. Download the data from the server. - FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket()); - VectorSchemaRoot downloadedRoot; - try (VectorSchemaRoot root = stream.getRoot()) { - downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); - VectorLoader loader = new VectorLoader(downloadedRoot); - VectorUnloader unloader = new VectorUnloader(root); - while (stream.next()) { - loader.load(unloader.getRecordBatch()); + for (FlightEndpoint endpoint : info.getEndpoints()) { + // 3. Download the data from the server. + List locations = endpoint.getLocations(); + if (locations.size() == 0) { + locations = Collections.singletonList(new Location(host, port)); } - } + for (Location location : locations) { + System.out.println("Verifying location " + location.getHost() + ":" + location.getPort()); + FlightClient readClient = new FlightClient(allocator, location); + FlightStream stream = readClient.getStream(endpoint.getTicket()); + VectorSchemaRoot downloadedRoot; + try (VectorSchemaRoot root = stream.getRoot()) { + downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + VectorLoader loader = new VectorLoader(downloadedRoot); + VectorUnloader unloader = new VectorUnloader(root); + while (stream.next()) { + loader.load(unloader.getRecordBatch()); + } + } - Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot); + // 4. Validate the data. + Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot); + } + } } }