From a5ffc839c26c72edb65708a9754db97bc23106af Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 14 Feb 2024 15:07:23 -0500
Subject: [PATCH] WIP: [FlightRPC][C++][Java] Add fallback URI scheme
Co-authored-by: Andrew Lamb
---
cpp/src/arrow/flight/flight_internals_test.cc | 5 +
.../flight_integration_test.cc | 4 +
.../integration_tests/test_integration.cc | 47 ++++++++
cpp/src/arrow/flight/types.cc | 6 +
cpp/src/arrow/flight/types.h | 8 ++
dev/archery/archery/integration/runner.py | 5 +
docs/source/format/Flight.rst | 70 +++++++++---
format/Flight.proto | 20 ++--
go/arrow/flight/server.go | 8 ++
.../internal/flight_integration/scenario.go | 52 +++++++++
.../org/apache/arrow/flight/Location.java | 13 +++
.../apache/arrow/flight/LocationSchemes.java | 1 +
.../arrow/flight/TestBasicOperation.java | 6 +
.../LocationReuseConnectionScenario.java | 66 +++++++++++
.../flight/integration/tests/Scenarios.java | 1 +
.../integration/tests/IntegrationTest.java | 65 -----------
.../integration/tests/IntegrationTest.java | 5 +
.../client/ArrowFlightSqlClientHandler.java | 7 ++
.../arrow/driver/jdbc/ResultSetTest.java | 67 +++++++++++
.../jdbc/utils/FallbackFlightSqlProducer.java | 108 ++++++++++++++++++
20 files changed, 474 insertions(+), 90 deletions(-)
create mode 100644 java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/LocationReuseConnectionScenario.java
delete mode 100644 java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java
create mode 100644 java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java
diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc
index a1c5250ba66fa..57f4f3e030420 100644
--- a/cpp/src/arrow/flight/flight_internals_test.cc
+++ b/cpp/src/arrow/flight/flight_internals_test.cc
@@ -353,6 +353,11 @@ TEST(FlightTypes, LocationUnknownScheme) {
ASSERT_OK(Location::Parse("https://example.com/foo"));
}
+TEST(FlightTypes, LocationFallback) {
+ EXPECT_EQ("arrow-flight-reuse-connection://?", Location::ReuseConnection().ToString());
+ EXPECT_EQ("arrow-flight-reuse-connection", Location::ReuseConnection().scheme());
+}
+
TEST(FlightTypes, RoundtripStatus) {
// Make sure status codes round trip through our conversions
diff --git a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
index 6f3115cc5ab8a..92c088b7fae08 100644
--- a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
+++ b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
@@ -71,6 +71,10 @@ TEST(FlightIntegration, ExpirationTimeRenewFlightEndpoint) {
ASSERT_OK(RunScenario("expiration_time:renew_flight_endpoint"));
}
+TEST(FlightIntegration, LocationReuseConnection) {
+ ASSERT_OK(RunScenario("location:reuse_connection"));
+}
+
TEST(FlightIntegration, SessionOptions) { ASSERT_OK(RunScenario("session_options")); }
TEST(FlightIntegration, PollFlightInfo) { ASSERT_OK(RunScenario("poll_flight_info")); }
diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc
index d4e0a2cda5bd8..6ba5d9c352da1 100644
--- a/cpp/src/arrow/flight/integration_tests/test_integration.cc
+++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc
@@ -2079,6 +2079,50 @@ class FlightSqlExtensionScenario : public FlightSqlScenario {
return Status::OK();
}
};
+
+/// \brief The server for testing arrow-flight-reuse-connection://.
+class ReuseConnectionServer : public FlightServerBase {
+ public:
+ Status GetFlightInfo(const ServerCallContext& context,
+ const FlightDescriptor& descriptor,
+ std::unique_ptr* info) override {
+ auto location = Location::ReuseConnection();
+ auto endpoint = FlightEndpoint{{"reuse"}, {location}};
+ ARROW_ASSIGN_OR_RAISE(auto info_data, FlightInfo::Make(arrow::Schema({}), descriptor,
+ {endpoint}, -1, -1));
+ *info = std::make_unique(std::move(info_data));
+ return Status::OK();
+ }
+};
+
+/// \brief A scenario for testing arrow-flight-reuse-connection://.
+class ReuseConnectionScenario : public Scenario {
+ Status MakeServer(std::unique_ptr* server,
+ FlightServerOptions* options) override {
+ *server = std::make_unique();
+ return Status::OK();
+ }
+
+ Status MakeClient(FlightClientOptions* options) override { return Status::OK(); }
+
+ Status RunClient(std::unique_ptr client) override {
+ auto descriptor = FlightDescriptor::Command("reuse");
+ ARROW_ASSIGN_OR_RAISE(auto info, client->GetFlightInfo(descriptor));
+ if (info->endpoints().size() != 1) {
+ return Status::Invalid("Expected 1 endpoint, got ", info->endpoints().size());
+ }
+ const auto& endpoint = info->endpoints().front();
+ if (endpoint.locations.size() != 1) {
+ return Status::Invalid("Expected 1 location, got ",
+ info->endpoints().front().locations.size());
+ } else if (endpoint.locations.front().ToString() !=
+ "arrow-flight-reuse-connection://?") {
+ return Status::Invalid("Expected arrow-flight-reuse-connection://?, got ",
+ endpoint.locations.front().ToString());
+ }
+ return Status::OK();
+ }
+};
} // namespace
Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) {
@@ -2103,6 +2147,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr*
} else if (scenario_name == "expiration_time:renew_flight_endpoint") {
*out = std::make_shared();
return Status::OK();
+ } else if (scenario_name == "location:reuse_connection") {
+ *out = std::make_shared();
+ return Status::OK();
} else if (scenario_name == "session_options") {
*out = std::make_shared();
return Status::OK();
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index 11b2baafad220..a1b799a3a069e 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -829,6 +829,12 @@ arrow::Result Location::Parse(const std::string& uri_string) {
return location;
}
+const Location& Location::ReuseConnection() {
+ static Location kFallback =
+ Location::Parse("arrow-flight-reuse-connection://?").ValueOrDie();
+ return kFallback;
+}
+
arrow::Result Location::ForGrpcTcp(const std::string& host, const int port) {
std::stringstream uri_string;
uri_string << "grpc+tcp://" << host << ':' << port;
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 4b17149aa2d46..c96aa428b054e 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -424,6 +424,14 @@ struct ARROW_FLIGHT_EXPORT Location {
/// \brief Initialize a location by parsing a URI string
static arrow::Result Parse(const std::string& uri_string);
+ /// \brief Get the fallback URI.
+ ///
+ /// arrow-flight-reuse-connection:// means that a client may attempt to
+ /// reuse an existing connection to a Flight service to fetch data instead
+ /// of creating a new connection to one of the other locations listed in a
+ /// FlightEndpoint response.
+ static const Location& ReuseConnection();
+
/// \brief Initialize a location for a non-TLS, gRPC-based Flight
/// service from a host and port
/// \param[in] host The hostname to connect to
diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py
index e984468bc5052..7cdd5071b96c6 100644
--- a/dev/archery/archery/integration/runner.py
+++ b/dev/archery/archery/integration/runner.py
@@ -608,6 +608,11 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True,
"RenewFlightEndpoint are working as expected."),
skip_testers={"JS", "C#", "Rust"},
),
+ Scenario(
+ "location:reuse_connection",
+ description="Ensure arrow-flight-reuse-connection is accepted.",
+ skip_testers={"JS", "C#", "Rust"},
+ ),
Scenario(
"session_options",
description="Ensure Flight SQL Sessions work as expected.",
diff --git a/docs/source/format/Flight.rst b/docs/source/format/Flight.rst
index 73ca848b5e996..7ee84952b4350 100644
--- a/docs/source/format/Flight.rst
+++ b/docs/source/format/Flight.rst
@@ -121,6 +121,13 @@ A client that wishes to download the data would:
connection to the original server to fetch data. Otherwise, the
client must connect to one of the indicated locations.
+ The server may list "itself" as a location alongside other server
+ locations. Normally this requires the server to know its public
+ address, but it may also use the special URI string
+ ``arrow-flight-reuse-connection://?`` to tell clients that they may
+ reuse an existing connection to the same server, without having to
+ be able to name itself. See `Connection Reuse`_ below.
+
In this way, the locations inside an endpoint can also be thought
of as performing look-aside load balancing or service discovery
functions. And the endpoints can represent data that is partitioned
@@ -307,29 +314,58 @@ well, in which case any `authentication method supported by gRPC
.. _Mutual TLS (mTLS): https://grpc.io/docs/guides/auth/#supported-auth-mechanisms
-Transport Implementations
-=========================
+Location URIs
+=============
Flight is primarily defined in terms of its Protobuf and gRPC
specification below, but Arrow implementations may also support
-alternative transports (see :ref:`status-flight-rpc`). In that case,
-implementations should use the following URI schemes for the given
-transport implementations:
-
-+----------------------------+----------------------------+
-| Transport | URI Scheme |
-+============================+============================+
-| gRPC (plaintext) | grpc: or grpc+tcp: |
-+----------------------------+----------------------------+
-| gRPC (TLS) | grpc+tls: |
-+----------------------------+----------------------------+
-| gRPC (Unix domain socket) | grpc+unix: |
-+----------------------------+----------------------------+
-| UCX_ (plaintext) | ucx: |
-+----------------------------+----------------------------+
+alternative transports (see :ref:`status-flight-rpc`). Clients and
+servers need to know which transport to use for a given URI in a
+Location, so Flight implementations should use the following URI
+schemes for the given transports:
+
++----------------------------+--------------------------------+
+| Transport | URI Scheme |
++============================+================================+
+| gRPC (plaintext) | grpc: or grpc+tcp: |
++----------------------------+--------------------------------+
+| gRPC (TLS) | grpc+tls: |
++----------------------------+--------------------------------+
+| gRPC (Unix domain socket) | grpc+unix: |
++----------------------------+--------------------------------+
+| (reuse connection) | arrow-flight-reuse-connection: |
++----------------------------+--------------------------------+
+| UCX_ (plaintext) | ucx: |
++----------------------------+--------------------------------+
.. _UCX: https://openucx.org/
+Connection Reuse
+----------------
+
+"Reuse connection" above is not a particular transport. Instead, it
+means that the client may try to execute DoGet against the same server
+(and through the same connection) that it originally obtained the
+FlightInfo from (i.e., that it called GetFlightInfo against). This is
+interpreted the same way as when no specific ``Location`` are
+returned.
+
+This allows the server to return "itself" as one possible location to
+fetch data without having to know its own public address, which can be
+useful in deployments where knowing this would be difficult or
+impossible. For example, a developer may forward a remote service in
+a cloud environment to their local machine; in this case, the remote
+service would have no way to know the local hostname and port that it
+is being accessed over.
+
+For compatibility reasons, the URI should always be
+``arrow-flight-reuse-connection://?``, with the trailing empty query
+string. Java's URI implementation does not accept ``scheme:`` or
+``scheme://``, and C++'s implementation does not accept an empty
+string, so the obvious candidates are not compatible. The chosen
+representation can be parsed by both implementations, as well as Go's
+``net/url`` and Python's ``urllib.parse``.
+
Error Handling
==============
diff --git a/format/Flight.proto b/format/Flight.proto
index 59714108e1cbc..4963e8c09ae47 100644
--- a/format/Flight.proto
+++ b/format/Flight.proto
@@ -369,7 +369,7 @@ message FlightInfo {
/*
* Application-defined metadata.
- *
+ *
* There is no inherent or required relationship between this
* and the app_metadata fields in the FlightEndpoints or resulting
* FlightData messages. Since this metadata is application-defined,
@@ -440,11 +440,15 @@ message FlightEndpoint {
* be redeemed on the current service where the ticket was
* generated.
*
- * If the list is not empty, the expectation is that the ticket can
- * be redeemed at any of the locations, and that the data returned
- * will be equivalent. In this case, the ticket may only be redeemed
- * at one of the given locations, and not (necessarily) on the
- * current service.
+ * If the list is not empty, the expectation is that the ticket can be
+ * redeemed at any of the locations, and that the data returned will be
+ * equivalent. In this case, the ticket may only be redeemed at one of the
+ * given locations, and not (necessarily) on the current service. If one
+ * of the given locations is "arrow-flight-reuse-connection://?", the
+ * client may redeem the ticket on the service where the ticket was
+ * generated (i.e., the same as above), in addition to the other
+ * locations. (This URI was chosen to maximize compatibility, as 'scheme:'
+ * or 'scheme://' are not accepted by Java's java.net.URI.)
*
* In other words, an application can use multiple locations to
* represent redundant and/or load balanced services.
@@ -460,7 +464,7 @@ message FlightEndpoint {
/*
* Application-defined metadata.
- *
+ *
* There is no inherent or required relationship between this
* and the app_metadata fields in the FlightInfo or resulting
* FlightData messages. Since this metadata is application-defined,
@@ -587,7 +591,7 @@ message SetSessionOptionsResult {
message Error {
ErrorValue value = 1;
}
-
+
map errors = 1;
}
diff --git a/go/arrow/flight/server.go b/go/arrow/flight/server.go
index c70aceabcfe8e..9f76feda5c268 100644
--- a/go/arrow/flight/server.go
+++ b/go/arrow/flight/server.go
@@ -81,6 +81,14 @@ const (
CancelStatusNotCancellable = flight.CancelStatus_CANCEL_STATUS_NOT_CANCELLABLE
)
+// Constants for Location
+const (
+ // LocationReuseConnection is a special location that tells clients
+ // they may fetch the data from the same service that they obtained
+ // the FlightEndpoint response from.
+ LocationReuseConnection = "arrow-flight-reuse-connection://?"
+)
+
// RegisterFlightServiceServer registers an existing flight server onto an
// existing grpc server, or anything that is a grpc service registrar.
func RegisterFlightServiceServer(s *grpc.Server, srv FlightServer) {
diff --git a/go/arrow/internal/flight_integration/scenario.go b/go/arrow/internal/flight_integration/scenario.go
index 91658a694ecab..6c0d6d9048986 100644
--- a/go/arrow/internal/flight_integration/scenario.go
+++ b/go/arrow/internal/flight_integration/scenario.go
@@ -69,6 +69,8 @@ func GetScenario(name string, args ...string) Scenario {
return &expirationTimeCancelFlightInfoScenarioTester{}
case "expiration_time:renew_flight_endpoint":
return &expirationTimeRenewFlightEndpointScenarioTester{}
+ case "location:reuse_connection":
+ return &locationReuseConnectionScenarioTester{}
case "poll_flight_info":
return &pollFlightInfoScenarioTester{}
case "app_metadata_flight_info_endpoint":
@@ -1136,6 +1138,56 @@ func (tester *expirationTimeRenewFlightEndpointScenarioTester) RunClient(addr st
return nil
}
+type locationReuseConnectionScenarioTester struct {
+ flight.BaseFlightServer
+}
+
+func (m *locationReuseConnectionScenarioTester) GetFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
+ return &flight.FlightInfo{
+ Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{}, nil), memory.DefaultAllocator),
+ FlightDescriptor: desc,
+ Endpoint: []*flight.FlightEndpoint{{
+ Ticket: &flight.Ticket{Ticket: []byte("reuse")},
+ Location: []*flight.Location{{Uri: flight.LocationReuseConnection}},
+ }},
+ TotalRecords: -1,
+ TotalBytes: -1,
+ }, nil
+}
+
+func (tester *locationReuseConnectionScenarioTester) MakeServer(port int) flight.Server {
+ srv := flight.NewServerWithMiddleware(nil)
+ srv.RegisterFlightService(tester)
+ initServer(port, srv)
+ return srv
+}
+
+func (tester *locationReuseConnectionScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
+ client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
+ if err != nil {
+ return err
+ }
+ defer client.Close()
+
+ ctx := context.Background()
+ info, err := client.GetFlightInfo(ctx, &flight.FlightDescriptor{Type: flight.DescriptorCMD, Cmd: []byte("reuse")})
+ if err != nil {
+ return err
+ }
+
+ if len(info.Endpoint) != 1 {
+ return fmt.Errorf("Expected 1 endpoint, got %d", len(info.Endpoint))
+ }
+ endpoint := info.Endpoint[0]
+ if len(endpoint.Location) != 1 {
+ return fmt.Errorf("Expected 1 location, got %d", len(endpoint.Location))
+ } else if endpoint.Location[0].Uri != flight.LocationReuseConnection {
+ return fmt.Errorf("Expected %s, got %s", flight.LocationReuseConnection, endpoint.Location[0].Uri)
+ }
+
+ return nil
+}
+
type pollFlightInfoScenarioTester struct {
flight.BaseFlightServer
}
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java
index fe192aa0c3f9d..2eb3139c9dcdd 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java
@@ -93,6 +93,19 @@ Flight.Location toProtocol() {
return Flight.Location.newBuilder().setUri(uri.toString()).build();
}
+ /**
+ * Construct a special URI to indicate to clients that they may fetch data by reusing
+ * an existing connection to a Flight RPC server.
+ */
+ public static Location reuseConnection() {
+ try {
+ return new Location(new URI(LocationSchemes.REUSE_CONNECTION, "", "", "", null));
+ } catch (URISyntaxException e) {
+ // This should never happen.
+ throw new IllegalArgumentException(e);
+ }
+ }
+
/**
* Construct a URI for a Flight+gRPC server without transport security.
*
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java
index 872e5b1c22deb..f1dbfb95f237e 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java
@@ -25,6 +25,7 @@ public final class LocationSchemes {
public static final String GRPC_INSECURE = "grpc+tcp";
public static final String GRPC_DOMAIN_SOCKET = "grpc+unix";
public static final String GRPC_TLS = "grpc+tls";
+ public static final String REUSE_CONNECTION = "arrow-flight-reuse-connection";
private LocationSchemes() {
throw new AssertionError("Do not instantiate this class.");
diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
index ae520ee9b991b..bc34b5e6d6074 100644
--- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
+++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
@@ -79,6 +79,12 @@ public void fastPathDefaults() {
Assertions.assertFalse(ArrowMessage.ENABLE_ZERO_COPY_WRITE);
}
+ @Test
+ public void fallbackLocation() {
+ Assertions.assertEquals("arrow-flight-reuse-connection://?",
+ Location.reuseConnection().getUri().toString());
+ }
+
/**
* ARROW-6017: we should be able to construct locations for unknown schemes.
*/
diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/LocationReuseConnectionScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/LocationReuseConnectionScenario.java
new file mode 100644
index 0000000000000..7ffd09c5bf4e8
--- /dev/null
+++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/LocationReuseConnectionScenario.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.integration.tests;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+/** Test the 'arrow-flight-reuse-connection' scheme. */
+public class LocationReuseConnectionScenario implements Scenario {
+ @Override
+ public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception {
+ return new ReuseConnectionProducer();
+ }
+
+ @Override
+ public void buildServer(FlightServer.Builder builder) throws Exception {
+ }
+
+ @Override
+ public void client(BufferAllocator allocator, Location location, FlightClient client)
+ throws Exception {
+ final FlightInfo info = client.getInfo(FlightDescriptor.command("reuse".getBytes(StandardCharsets.UTF_8)));
+ IntegrationAssertions.assertEquals(1, info.getEndpoints().size());
+ IntegrationAssertions.assertEquals(1, info.getEndpoints().get(0).getLocations().size());
+ Location actual = info.getEndpoints().get(0).getLocations().get(0);
+ IntegrationAssertions.assertEquals(Location.reuseConnection().getUri(), actual.getUri());
+ }
+
+ private static class ReuseConnectionProducer extends NoOpFlightProducer {
+ @Override
+ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
+ List endpoints = Collections.singletonList(
+ new FlightEndpoint(new Ticket(new byte[0]), Location.reuseConnection()));
+ return new FlightInfo(
+ new Schema(Collections.emptyList()), descriptor, endpoints, /*bytes*/ -1, /*records*/ -1);
+ }
+ }
+}
diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java
index 6878c22c5ccdc..5ce82e712ab77 100644
--- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java
+++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java
@@ -44,6 +44,7 @@ private Scenarios() {
scenarios.put("expiration_time:renew_flight_endpoint", ExpirationTimeRenewFlightEndpointScenario::new);
scenarios.put("expiration_time:do_get", ExpirationTimeDoGetScenario::new);
scenarios.put("expiration_time:list_actions", ExpirationTimeListActionsScenario::new);
+ scenarios.put("location:reuse_connection", LocationReuseConnectionScenario::new);
scenarios.put("middleware", MiddlewareScenario::new);
scenarios.put("ordered", OrderedScenario::new);
scenarios.put("poll_flight_info", PollFlightInfoScenario::new);
diff --git a/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java
deleted file mode 100644
index dfb9a810857ba..0000000000000
--- a/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.arrow.flight.integration.tests;
-
-import org.apache.arrow.flight.FlightClient;
-import org.apache.arrow.flight.FlightServer;
-import org.apache.arrow.flight.Location;
-import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.memory.RootAllocator;
-import org.junit.jupiter.api.Test;
-
-/**
- * Run the integration test scenarios in-process.
- */
-class IntegrationTest {
- @Test
- void authBasicProto() throws Exception {
- testScenario("auth:basic_proto");
- }
-
- @Test
- void middleware() throws Exception {
- testScenario("middleware");
- }
-
- @Test
- void flightSql() throws Exception {
- testScenario("flight_sql");
- }
-
- void testScenario(String scenarioName) throws Exception {
- try (final BufferAllocator allocator = new RootAllocator()) {
- final FlightServer.Builder builder = FlightServer.builder()
- .allocator(allocator)
- .location(Location.forGrpcInsecure("0.0.0.0", 0));
- final Scenario scenario = Scenarios.getScenario(scenarioName);
- scenario.buildServer(builder);
- builder.producer(scenario.producer(allocator, Location.forGrpcInsecure("0.0.0.0", 0)));
-
- try (final FlightServer server = builder.build()) {
- server.start();
-
- final Location location = Location.forGrpcInsecure("localhost", server.getPort());
- try (final FlightClient client = FlightClient.builder(allocator, location).build()) {
- scenario.client(allocator, location, client);
- }
- }
- }
- }
-}
diff --git a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java
index f814427567ae9..40aceb336e0ea 100644
--- a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java
+++ b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java
@@ -53,6 +53,11 @@ void expirationTimeRenewFlightEndpoint() throws Exception {
testScenario("expiration_time:renew_flight_endpoint");
}
+ @Test
+ void locationReuseConnection() throws Exception {
+ testScenario("location:reuse_connection");
+ }
+
@Test
void middleware() throws Exception {
testScenario("middleware");
diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
index 234820bd41823..a47cffc2fcf09 100644
--- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
+++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
@@ -124,6 +124,13 @@ public List getStreams(final FlightInfo flightInfo)
// It would also be good to identify when the reported location is the same as the original connection's
// Location and skip creating a FlightClient in that scenario.
final URI endpointUri = endpoint.getLocations().get(0).getUri();
+
+ if (endpointUri.getScheme().equals(LocationSchemes.REUSE_CONNECTION)) {
+ endpoints.add(new CloseableEndpointStreamPair(
+ sqlClient.getStream(endpoint.getTicket(), getOptions()), null));
+ continue;
+ }
+
final Builder builderForEndpoint = new Builder(ArrowFlightSqlClientHandler.this.builder)
.withHost(endpointUri.getHost())
.withPort(endpointUri.getPort())
diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java
index 0e3e015a04636..ad01e8767b793 100644
--- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java
+++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java
@@ -39,6 +39,7 @@
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
@@ -46,6 +47,7 @@
import java.util.concurrent.CountDownLatch;
import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers;
+import org.apache.arrow.driver.jdbc.utils.FallbackFlightSqlProducer;
import org.apache.arrow.driver.jdbc.utils.PartitionedFlightSqlProducer;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightProducer;
@@ -55,6 +57,7 @@
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
@@ -63,6 +66,7 @@
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
+import org.junit.jupiter.api.Assertions;
import org.junit.rules.ErrorCollector;
import com.google.common.collect.ImmutableSet;
@@ -455,6 +459,69 @@ allocator, forGrpcInsecure("localhost", 0), rootProducer)
}
}
+ @Test
+ public void testFallbackFlightServer() throws Exception {
+ final Schema schema = new Schema(
+ Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType())));
+ try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) {
+ resultData.setRowCount(1);
+ ((IntVector) resultData.getVector(0)).set(0, 1);
+
+ try (final FallbackFlightSqlProducer rootProducer = new FallbackFlightSqlProducer(resultData);
+ FlightServer rootServer = FlightServer.builder(
+ allocator, forGrpcInsecure("localhost", 0), rootProducer)
+ .build()
+ .start();
+ Connection newConnection = DriverManager.getConnection(String.format(
+ "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false",
+ rootServer.getLocation().getUri().getHost(), rootServer.getPort()));
+ Statement newStatement = newConnection.createStatement();
+ ResultSet result = newStatement.executeQuery("fallback")) {
+ List actualData = new ArrayList<>();
+ while (result.next()) {
+ actualData.add(result.getInt(1));
+ }
+
+ // Assert
+ assertEquals(resultData.getRowCount(), actualData.size());
+ assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0)));
+ }
+ }
+ }
+
+ @Test
+ public void testFallbackSecondFlightServer() throws Exception {
+ final Schema schema = new Schema(
+ Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType())));
+ try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) {
+ resultData.setRowCount(1);
+ ((IntVector) resultData.getVector(0)).set(0, 1);
+
+ try (final FallbackFlightSqlProducer rootProducer = new FallbackFlightSqlProducer(resultData);
+ FlightServer rootServer = FlightServer.builder(
+ allocator, forGrpcInsecure("localhost", 0), rootProducer)
+ .build()
+ .start();
+ Connection newConnection = DriverManager.getConnection(String.format(
+ "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false",
+ rootServer.getLocation().getUri().getHost(), rootServer.getPort()));
+ Statement newStatement = newConnection.createStatement()) {
+
+ // TODO(https://github.com/apache/arrow/issues/38573)
+ // XXX: we could try to assert more structure but then we'd have to hardcode
+ // a particular exception chain which may be fragile
+ Assertions.assertThrows(SQLException.class, () -> {
+ try (ResultSet result = newStatement.executeQuery("fallback with error")) {
+ // Empty body
+ }
+ });
+
+ }
+ }
+ }
+
@Test
public void testShouldRunSelectQueryWithEmptyVectorsEmbedded() throws Exception {
try (Statement statement = connection.createStatement();
diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java
new file mode 100644
index 0000000000000..2257220a4c845
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.Result;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.flight.sql.BasicFlightSqlProducer;
+import org.apache.arrow.flight.sql.impl.FlightSql;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.Message;
+
+public class FallbackFlightSqlProducer extends BasicFlightSqlProducer {
+ private final VectorSchemaRoot data;
+
+ public FallbackFlightSqlProducer(VectorSchemaRoot resultData) {
+ this.data = resultData;
+ }
+
+ @Override
+ protected List determineEndpoints(
+ T request, FlightDescriptor flightDescriptor, Schema schema) {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request,
+ CallContext context, StreamListener listener) {
+ final FlightSql.ActionCreatePreparedStatementResult.Builder resultBuilder =
+ FlightSql.ActionCreatePreparedStatementResult.newBuilder()
+ .setPreparedStatementHandle(request.getQueryBytes());
+
+ final ByteString datasetSchemaBytes = ByteString.copyFrom(data.getSchema().serializeAsMessage());
+
+ resultBuilder.setDatasetSchema(datasetSchemaBytes);
+ listener.onNext(new Result(Any.pack(resultBuilder.build()).toByteArray()));
+ listener.onCompleted();
+ }
+
+ @Override
+ public FlightInfo getFlightInfoStatement(
+ FlightSql.CommandStatementQuery command, CallContext context, FlightDescriptor descriptor) {
+ return getFlightInfo(descriptor, command.getQuery());
+ }
+
+ @Override
+ public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command,
+ CallContext context, FlightDescriptor descriptor) {
+ return getFlightInfo(descriptor, command.getPreparedStatementHandle().toStringUtf8());
+ }
+
+ @Override
+ public void getStreamStatement(FlightSql.TicketStatementQuery ticket, CallContext context,
+ ServerStreamListener listener) {
+ listener.start(data);
+ listener.putNext();
+ listener.completed();
+ }
+
+ @Override
+ public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request,
+ CallContext context, StreamListener listener) {
+ listener.onCompleted();
+ }
+
+ private FlightInfo getFlightInfo(FlightDescriptor descriptor, String query) {
+ final List endpoints;
+ final Ticket ticket = new Ticket(
+ Any.pack(FlightSql.TicketStatementQuery.getDefaultInstance()).toByteArray());
+ if (query.equals("fallback")) {
+ endpoints = Collections.singletonList(FlightEndpoint.builder(ticket, Location.reuseConnection()).build());
+ } else if (query.equals("fallback with error")) {
+ endpoints = Collections.singletonList(
+ FlightEndpoint.builder(ticket,
+ Location.forGrpcInsecure("localhost", 9999),
+ Location.reuseConnection()).build());
+ } else {
+ throw CallStatus.UNIMPLEMENTED.withDescription(query).toRuntimeException();
+ }
+ return FlightInfo.builder(data.getSchema(), descriptor, endpoints).build();
+ }
+}