Skip to content

Commit

Permalink
GH-38573: [Java][FlightRPC] Try all locations in JDBC driver
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Feb 16, 2024
1 parent a03d957 commit ba56395
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,8 @@ public FlightClient build() {

builder
.maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
.maxInboundMessageSize(maxInboundMessageSize);
.maxInboundMessageSize(maxInboundMessageSize)
.maxInboundMetadataSize(maxInboundMessageSize);
return new FlightClient(allocator, builder.build(), middleware);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,26 +116,47 @@ public List<CloseableEndpointStreamPair> getStreams(final FlightInfo flightInfo)
sqlClient.getStream(endpoint.getTicket(), getOptions()), null));
} else {
// Clone the builder and then set the new endpoint on it.
// GH-38573: This code currently only tries the first Location and treats a failure as fatal.
// This should be changed to try other Locations that are available.


// GH-38574: Currently a new FlightClient will be made for each partition that returns a non-empty Location
// then disposed of. It may be better to cache clients because a server may report the same Locations.
// It would also be good to identify when the reported location is the same as the original connection's
// Location and skip creating a FlightClient in that scenario.
final URI endpointUri = endpoint.getLocations().get(0).getUri();
final Builder builderForEndpoint = new Builder(ArrowFlightSqlClientHandler.this.builder)
.withHost(endpointUri.getHost())
.withPort(endpointUri.getPort())
.withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS));

final ArrowFlightSqlClientHandler endpointHandler = builderForEndpoint.build();
try {
endpoints.add(new CloseableEndpointStreamPair(
endpointHandler.sqlClient.getStream(endpoint.getTicket(),
endpointHandler.getOptions()), endpointHandler.sqlClient));
} catch (Exception ex) {
AutoCloseables.close(endpointHandler);
List<Exception> exceptions = new ArrayList<>();
CloseableEndpointStreamPair stream = null;
for (Location location : endpoint.getLocations()) {
final URI endpointUri = location.getUri();
final Builder builderForEndpoint = new Builder(ArrowFlightSqlClientHandler.this.builder)
.withHost(endpointUri.getHost())
.withPort(endpointUri.getPort())
.withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS));

ArrowFlightSqlClientHandler endpointHandler = null;
try {
endpointHandler = builderForEndpoint.build();
stream = new CloseableEndpointStreamPair(
endpointHandler.sqlClient.getStream(endpoint.getTicket(),
endpointHandler.getOptions()), endpointHandler.sqlClient);
// Make sure we actually get data from the server
stream.getStream().getSchema();
} catch (Exception ex) {
if (endpointHandler != null) {
AutoCloseables.close(endpointHandler);
}
exceptions.add(ex);
continue;
}
break;
}
if (stream != null) {
endpoints.add(stream);
} else if (exceptions.isEmpty()) {
// This should never happen...
throw new IllegalStateException("Could not connect to endpoint and no errors occurred");
} else {
Exception ex = exceptions.remove(0);
while (!exceptions.isEmpty()) {
ex.addSuppressed(exceptions.remove(exceptions.size() - 1));
}
throw ex;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,14 @@ private CloseableEndpointStreamPair next(final EndpointStreamSupplier endpointSt
if (endpoint != null) {
return endpoint;
}
} catch (final ExecutionException | InterruptedException | CancellationException e) {
} catch (final ExecutionException e) {
// Unwrap one layer
final Throwable cause = e.getCause();
if (cause instanceof FlightRuntimeException) {
throw (FlightRuntimeException) cause;
}
throw AvaticaConnection.HELPER.wrap(e.getMessage(), e);
} catch (InterruptedException | CancellationException e) {
throw AvaticaConnection.HELPER.wrap(e.getMessage(), e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
Expand All @@ -49,7 +50,10 @@
import org.apache.arrow.driver.jdbc.utils.PartitionedFlightSqlProducer;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
Expand All @@ -63,6 +67,7 @@
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.jupiter.api.Assertions;
import org.junit.rules.ErrorCollector;

import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -351,7 +356,7 @@ public void testShouldInterruptFlightStreamsIfQueryIsCancelledMidProcessingForTi
.toString(),
anyOf(is(format("Error while executing SQL \"%s\": Query canceled", query)),
allOf(containsString(format("Error while executing SQL \"%s\"", query)),
containsString("CANCELLED"))));
anyOf(containsString("CANCELLED"), containsString("Cancelling")))));
}
}

Expand Down Expand Up @@ -455,6 +460,90 @@ allocator, forGrpcInsecure("localhost", 0), rootProducer)
}
}

/**
 * Verifies that a query fails with a {@link FlightStatusCode#UNAVAILABLE} Flight error when none
 * of the locations advertised for an endpoint can serve its ticket.
 *
 * <p>The single endpoint lists two locations (127.0.0.2:1234 and 127.0.0.3:1234); this test only
 * starts the root server and starts nothing at either of those addresses.
 *
 * <p>NOTE(review): despite the "IgnoreFailure" name, this is the scenario in which every location
 * fails; the name looks swapped with {@code testPartitionedFlightServerAllFailure} - confirm
 * before renaming.
 *
 * @throws Exception if a test fixture cannot be created or closed
 */
@Test
public void testPartitionedFlightServerIgnoreFailure() throws Exception {
  final Schema schema = new Schema(
      Collections.singletonList(Field.nullablePrimitive("int_column", new ArrowType.Int(32, true))));
  try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
    final FlightEndpoint firstEndpoint =
        new FlightEndpoint(new Ticket("first".getBytes(StandardCharsets.UTF_8)),
            Location.forGrpcInsecure("127.0.0.2", 1234),
            Location.forGrpcInsecure("127.0.0.3", 1234));

    try (final PartitionedFlightSqlProducer rootProducer = new PartitionedFlightSqlProducer(
            schema, firstEndpoint);
        FlightServer rootServer = FlightServer.builder(
                allocator, forGrpcInsecure("localhost", 0), rootProducer)
            .build()
            .start();
        Connection newConnection = DriverManager.getConnection(String.format(
            "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false",
            rootServer.getLocation().getUri().getHost(), rootServer.getPort()));
        Statement newStatement = newConnection.createStatement()) {
      final SQLException e = Assertions.assertThrows(SQLException.class, () -> {
        // try-with-resources so the ResultSet is released even though iteration is expected to fail.
        try (ResultSet result = newStatement.executeQuery("Select partitioned_data")) {
          while (result.next()) {
            // Intentionally empty: rows are only consumed so that a fetch error can surface.
          }
        }
      });
      final Throwable cause = e.getCause();
      // Include the actual cause so a mismatch is diagnosable from the failure message alone.
      Assertions.assertTrue(cause instanceof FlightRuntimeException,
          "Expected the cause to be a FlightRuntimeException but was: " + cause);
      final FlightRuntimeException fre = (FlightRuntimeException) cause;
      Assertions.assertEquals(FlightStatusCode.UNAVAILABLE, fre.status().code());
    }
  }
}

/**
 * Checks that a partition is still read in full when the first location advertised for its
 * endpoint is not served but a later one is.
 *
 * <p>The endpoint lists 127.0.0.2:1234 first, where this test starts nothing, followed by the
 * location of a data-only server that the test starts itself. The query must return the single
 * row held by that data-only server.
 *
 * <p>NOTE(review): despite the "AllFailure" name, this scenario ends in success; the name looks
 * swapped with {@code testPartitionedFlightServerIgnoreFailure} - confirm before renaming.
 *
 * @throws Exception if a test fixture cannot be created or closed
 */
@Test
public void testPartitionedFlightServerAllFailure() throws Exception {
  // Arrange
  final Field intField = Field.nullablePrimitive("int_column", new ArrowType.Int(32, true));
  final Schema schema = new Schema(Collections.singletonList(intField));
  try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
      VectorSchemaRoot partition = VectorSchemaRoot.create(schema, allocator)) {
    // One row whose only column holds the value 1.
    partition.setRowCount(1);
    ((IntVector) partition.getVector(0)).set(0, 1);

    // The data-only node serves the partition for the "first" ticket.
    final Ticket dataTicket = new Ticket("first".getBytes(StandardCharsets.UTF_8));
    final FlightProducer dataProducer =
        new PartitionedFlightSqlProducer.DataOnlyFlightSqlProducer(dataTicket, partition);

    // Start the data-only node before the root node: its Location is needed for the endpoint.
    try (FlightServer dataServer = FlightServer.builder(
            allocator, forGrpcInsecure("localhost", 0), dataProducer)
        .build()) {
      dataServer.start();

      final Location unreachable = Location.forGrpcInsecure("127.0.0.2", 1234);
      final Ticket endpointTicket = new Ticket("first".getBytes(StandardCharsets.UTF_8));
      final FlightEndpoint endpoint =
          new FlightEndpoint(endpointTicket, unreachable, dataServer.getLocation());

      // Root node, JDBC connection, statement and query all share one resource scope.
      try (final PartitionedFlightSqlProducer rootProducer =
              new PartitionedFlightSqlProducer(schema, endpoint);
          FlightServer rootServer = FlightServer.builder(
                  allocator, forGrpcInsecure("localhost", 0), rootProducer)
              .build()
              .start();
          Connection jdbcConnection = DriverManager.getConnection(String.format(
              "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false",
              rootServer.getLocation().getUri().getHost(), rootServer.getPort()));
          Statement jdbcStatement = jdbcConnection.createStatement();
          // Act
          ResultSet rows = jdbcStatement.executeQuery("Select partitioned_data")) {
        final List<Integer> fetched = new ArrayList<>();
        while (rows.next()) {
          fetched.add(rows.getInt(1));
        }

        // Assert
        assertEquals(partition.getRowCount(), fetched.size());
        final int expectedValue = ((IntVector) partition.getVector(0)).get(0);
        assertTrue(fetched.contains(expectedValue));
      }
    }
  }
}

@Test
public void testShouldRunSelectQueryWithEmptyVectorsEmbedded() throws Exception {
try (Statement statement = connection.createStatement();
Expand Down

0 comments on commit ba56395

Please sign in to comment.