Skip to content

Commit

Permalink
WIP: [FlightRPC][C++][Java] Add fallback URI scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Feb 14, 2024
1 parent 2422994 commit 4b14981
Show file tree
Hide file tree
Showing 10 changed files with 268 additions and 18 deletions.
5 changes: 5 additions & 0 deletions cpp/src/arrow/flight/flight_internals_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@ TEST(FlightTypes, LocationUnknownScheme) {
ASSERT_OK(Location::Parse("https://example.com/foo"));
}

TEST(FlightTypes, LocationFallback) {
EXPECT_EQ("arrow-flight-reuse-connection://", Location::ReuseConnection().ToString());
EXPECT_EQ("arrow-flight-reuse-connection", Location::ReuseConnection().scheme());
}

TEST(FlightTypes, RoundtripStatus) {
// Make sure status codes round trip through our conversions

Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/flight/types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,12 @@ arrow::Result<Location> Location::Parse(const std::string& uri_string) {
return location;
}

const Location& Location::ReuseConnection() {
static Location kFallback =
Location::Parse("arrow-flight-reuse-connection://").ValueOrDie();
return kFallback;
}

arrow::Result<Location> Location::ForGrpcTcp(const std::string& host, const int port) {
std::stringstream uri_string;
uri_string << "grpc+tcp://" << host << ':' << port;
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/flight/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,14 @@ struct ARROW_FLIGHT_EXPORT Location {
/// \brief Initialize a location by parsing a URI string
static arrow::Result<Location> Parse(const std::string& uri_string);

/// \brief Get the fallback URI.
///
/// arrow-flight-reuse-connection:// means that a client may attempt to
/// reuse an existing connection to a Flight service to fetch data instead
/// of creating a new connection to one of the other locations listed in a
/// FlightEndpoint response.
static const Location& ReuseConnection();

/// \brief Initialize a location for a non-TLS, gRPC-based Flight
/// service from a host and port
/// \param[in] host The hostname to connect to
Expand Down
65 changes: 47 additions & 18 deletions docs/source/format/Flight.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ A client that wishes to download the data would:
server and not a different location, then it can return an empty
list of locations. The client can then reuse the existing
connection to the original server to fetch data. Otherwise, the
client must connect to one of the indicated locations.
client must connect to one of the indicated locations. The server
may list "itself" as a location in this case as well without having
to know its public address via ``arrow-flight-fallback://?``. See
`Fallback URIs`_ below.

In this way, the locations inside an endpoint can also be thought
of as performing look-aside load balancing or service discovery
Expand Down Expand Up @@ -307,29 +310,55 @@ well, in which case any `authentication method supported by gRPC

.. _Mutual TLS (mTLS): https://grpc.io/docs/guides/auth/#supported-auth-mechanisms

Transport Implementations
=========================
Location URIs
=============

Flight is primarily defined in terms of its Protobuf and gRPC
specification below, but Arrow implementations may also support
alternative transports (see :ref:`status-flight-rpc`). In that case,
implementations should use the following URI schemes for the given
transport implementations:

+----------------------------+----------------------------+
| Transport | URI Scheme |
+============================+============================+
| gRPC (plaintext) | grpc: or grpc+tcp: |
+----------------------------+----------------------------+
| gRPC (TLS) | grpc+tls: |
+----------------------------+----------------------------+
| gRPC (Unix domain socket) | grpc+unix: |
+----------------------------+----------------------------+
| UCX_ (plaintext) | ucx: |
+----------------------------+----------------------------+
alternative transports (see :ref:`status-flight-rpc`). Clients and
servers need to know which transport to use for a given URI in a
Location, so Flight implementations should use the following URI
schemes for the given transports:

+----------------------------+--------------------------------+
| Transport | URI Scheme |
+============================+================================+
| Fallback | arrow-flight-reuse-connection: |
+----------------------------+--------------------------------+
| gRPC (plaintext) | grpc: or grpc+tcp: |
+----------------------------+--------------------------------+
| gRPC (TLS) | grpc+tls: |
+----------------------------+--------------------------------+
| gRPC (Unix domain socket) | grpc+unix: |
+----------------------------+--------------------------------+
| UCX_ (plaintext) | ucx: |
+----------------------------+--------------------------------+

.. _UCX: https://openucx.org/

Fallback URIs
-------------

"Fallback" above is not a particular transport. Instead, it means
that the client may try to execute DoGet against the same server (and
through the same connection) that it originally obtained the
FlightInfo from (i.e., that it called GetFlightInfo against).

For compatibility reasons, the URI should always be
``arrow-flight-reuse-connection://?``. (Java's URI implementation
does not accept ``scheme:`` or ``scheme://``, and C++'s implementation
does not accept an empty string; the chosen representation can be
parsed by both implementations, as well as Go's ``net/url`` and
Python's ``urllib.parse``.)

This allows the server to return "itself" as one possible location to
fetch data without having to know its own public address, which can be
useful in deployments where knowing this would be difficult or
impossible. For example, a developer may forward a remote service in
a cloud environment to their local machine; in this case, the remote
service would have no way to know the local hostname and port that it
is being accessed over.

Error Handling
==============

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ Flight.Location toProtocol() {
return Flight.Location.newBuilder().setUri(uri.toString()).build();
}

/**
* Construct a special URI to indicate to clients that they may fetch data by reusing
* an existing connection to a Flight RPC server.
*/
public static Location reuseConnection() {
try {
return new Location(new URI(LocationSchemes.REUSE_CONNECTION, "", "", "", null));
} catch (URISyntaxException e) {
// This should never happen.
throw new IllegalArgumentException(e);
}
}

/**
* Construct a URI for a Flight+gRPC server without transport security.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public final class LocationSchemes {
public static final String GRPC_INSECURE = "grpc+tcp";
public static final String GRPC_DOMAIN_SOCKET = "grpc+unix";
public static final String GRPC_TLS = "grpc+tls";
public static final String REUSE_CONNECTION = "arrow-flight-reuse-connection";

private LocationSchemes() {
throw new AssertionError("Do not instantiate this class.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ public void fastPathDefaults() {
Assertions.assertFalse(ArrowMessage.ENABLE_ZERO_COPY_WRITE);
}

@Test
public void fallbackLocation() {
Assertions.assertEquals("arrow-flight-basic-operation://?",
Location.reuseConnection().getUri().toString());
}

/**
* ARROW-6017: we should be able to construct locations for unknown schemes.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ public List<CloseableEndpointStreamPair> getStreams(final FlightInfo flightInfo)
// It would also be good to identify when the reported location is the same as the original connection's
// Location and skip creating a FlightClient in that scenario.
final URI endpointUri = endpoint.getLocations().get(0).getUri();

if (endpointUri.getScheme().equals(LocationSchemes.REUSE_CONNECTION)) {
endpoints.add(new CloseableEndpointStreamPair(
sqlClient.getStream(endpoint.getTicket(), getOptions()), null));
continue;
}

final Builder builderForEndpoint = new Builder(ArrowFlightSqlClientHandler.this.builder)
.withHost(endpointUri.getHost())
.withPort(endpointUri.getPort())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CountDownLatch;

import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers;
import org.apache.arrow.driver.jdbc.utils.FallbackFlightSqlProducer;
import org.apache.arrow.driver.jdbc.utils.PartitionedFlightSqlProducer;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightProducer;
Expand All @@ -55,6 +57,7 @@
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
Expand All @@ -63,6 +66,7 @@
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.jupiter.api.Assertions;
import org.junit.rules.ErrorCollector;

import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -455,6 +459,69 @@ allocator, forGrpcInsecure("localhost", 0), rootProducer)
}
}

@Test
public void testFallbackFlightServer() throws Exception {
final Schema schema = new Schema(
Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType())));
try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) {
resultData.setRowCount(1);
((IntVector) resultData.getVector(0)).set(0, 1);

try (final FallbackFlightSqlProducer rootProducer = new FallbackFlightSqlProducer(resultData);
FlightServer rootServer = FlightServer.builder(
allocator, forGrpcInsecure("localhost", 0), rootProducer)
.build()
.start();
Connection newConnection = DriverManager.getConnection(String.format(
"jdbc:arrow-flight-sql://%s:%d/?useEncryption=false",
rootServer.getLocation().getUri().getHost(), rootServer.getPort()));
Statement newStatement = newConnection.createStatement();
ResultSet result = newStatement.executeQuery("fallback")) {
List<Integer> actualData = new ArrayList<>();
while (result.next()) {
actualData.add(result.getInt(1));
}

// Assert
assertEquals(resultData.getRowCount(), actualData.size());
assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0)));
}
}
}

@Test
public void testFallbackSecondFlightServer() throws Exception {
final Schema schema = new Schema(
Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType())));
try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) {
resultData.setRowCount(1);
((IntVector) resultData.getVector(0)).set(0, 1);

try (final FallbackFlightSqlProducer rootProducer = new FallbackFlightSqlProducer(resultData);
FlightServer rootServer = FlightServer.builder(
allocator, forGrpcInsecure("localhost", 0), rootProducer)
.build()
.start();
Connection newConnection = DriverManager.getConnection(String.format(
"jdbc:arrow-flight-sql://%s:%d/?useEncryption=false",
rootServer.getLocation().getUri().getHost(), rootServer.getPort()));
Statement newStatement = newConnection.createStatement()) {

// TODO(https://github.com/apache/arrow/issues/38573)
// XXX: we could try to assert more structure but then we'd have to hardcode
// a particular exception chain which may be fragile
Assertions.assertThrows(SQLException.class, () -> {
try (ResultSet result = newStatement.executeQuery("fallback with error")) {
// Empty body
}
});

}
}
}

@Test
public void testShouldRunSelectQueryWithEmptyVectorsEmbedded() throws Exception {
try (Statement statement = connection.createStatement();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.driver.jdbc.utils;

import java.util.Collections;
import java.util.List;

import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.sql.BasicFlightSqlProducer;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;

public class FallbackFlightSqlProducer extends BasicFlightSqlProducer {
private final VectorSchemaRoot data;

public FallbackFlightSqlProducer(VectorSchemaRoot resultData) {
this.data = resultData;
}

@Override
protected <T extends Message> List<FlightEndpoint> determineEndpoints(
T request, FlightDescriptor flightDescriptor, Schema schema) {
return Collections.emptyList();
}

@Override
public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request,
CallContext context, StreamListener<Result> listener) {
final FlightSql.ActionCreatePreparedStatementResult.Builder resultBuilder =
FlightSql.ActionCreatePreparedStatementResult.newBuilder()
.setPreparedStatementHandle(request.getQueryBytes());

final ByteString datasetSchemaBytes = ByteString.copyFrom(data.getSchema().serializeAsMessage());

resultBuilder.setDatasetSchema(datasetSchemaBytes);
listener.onNext(new Result(Any.pack(resultBuilder.build()).toByteArray()));
listener.onCompleted();
}

@Override
public FlightInfo getFlightInfoStatement(
FlightSql.CommandStatementQuery command, CallContext context, FlightDescriptor descriptor) {
return getFlightInfo(descriptor, command.getQuery());
}

@Override
public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command,
CallContext context, FlightDescriptor descriptor) {
return getFlightInfo(descriptor, command.getPreparedStatementHandle().toStringUtf8());
}

@Override
public void getStreamStatement(FlightSql.TicketStatementQuery ticket, CallContext context,
ServerStreamListener listener) {
listener.start(data);
listener.putNext();
listener.completed();
}

@Override
public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request,
CallContext context, StreamListener<Result> listener) {
listener.onCompleted();
}

private FlightInfo getFlightInfo(FlightDescriptor descriptor, String query) {
final List<FlightEndpoint> endpoints;
final Ticket ticket = new Ticket(
Any.pack(FlightSql.TicketStatementQuery.getDefaultInstance()).toByteArray());
if (query.equals("fallback")) {
endpoints = Collections.singletonList(FlightEndpoint.builder(ticket, Location.reuseConnection()).build());
} else if (query.equals("fallback with error")) {
endpoints = Collections.singletonList(
FlightEndpoint.builder(ticket,
Location.forGrpcInsecure("localhost", 9999),
Location.reuseConnection()).build());
} else {
throw CallStatus.UNIMPLEMENTED.withDescription(query).toRuntimeException();
}
return FlightInfo.builder(data.getSchema(), descriptor, endpoints).build();
}
}

0 comments on commit 4b14981

Please sign in to comment.