From 2b0d38f9f780250bd6b24c8430b101103186b9cb Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 14 Feb 2024 15:07:23 -0500 Subject: [PATCH] WIP: [FlightRPC][C++][Java] Add fallback URI scheme Co-authored-by: Andrew Lamb --- cpp/src/arrow/flight/flight_internals_test.cc | 5 + .../flight_integration_test.cc | 4 + .../integration_tests/test_integration.cc | 47 ++++++++ cpp/src/arrow/flight/types.cc | 6 + cpp/src/arrow/flight/types.h | 8 ++ dev/archery/archery/integration/runner.py | 5 + docs/source/format/Flight.rst | 70 +++++++++--- format/Flight.proto | 20 ++-- go/arrow/flight/server.go | 8 ++ .../internal/flight_integration/scenario.go | 52 +++++++++ .../org/apache/arrow/flight/Location.java | 13 +++ .../apache/arrow/flight/LocationSchemes.java | 1 + .../arrow/flight/TestBasicOperation.java | 6 + .../LocationReuseConnectionScenario.java | 66 +++++++++++ .../flight/integration/tests/Scenarios.java | 1 + .../integration/tests/IntegrationTest.java | 65 ----------- .../integration/tests/IntegrationTest.java | 5 + .../client/ArrowFlightSqlClientHandler.java | 7 ++ .../arrow/driver/jdbc/ResultSetTest.java | 67 +++++++++++ .../jdbc/utils/FallbackFlightSqlProducer.java | 108 ++++++++++++++++++ 20 files changed, 474 insertions(+), 90 deletions(-) create mode 100644 java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/LocationReuseConnectionScenario.java delete mode 100644 java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java create mode 100644 java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index a1c5250ba66fa..57f4f3e030420 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -353,6 +353,11 @@ TEST(FlightTypes, LocationUnknownScheme) { ASSERT_OK(Location::Parse("https://example.com/foo")); } +TEST(FlightTypes, LocationFallback) { + EXPECT_EQ("arrow-flight-reuse-connection://?", Location::ReuseConnection().ToString()); + EXPECT_EQ("arrow-flight-reuse-connection", Location::ReuseConnection().scheme()); +} + TEST(FlightTypes, RoundtripStatus) { // Make sure status codes round trip through our conversions diff --git a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc index 6f3115cc5ab8a..92c088b7fae08 100644 --- a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc +++ b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc @@ -71,6 +71,10 @@ TEST(FlightIntegration, ExpirationTimeRenewFlightEndpoint) { ASSERT_OK(RunScenario("expiration_time:renew_flight_endpoint")); } +TEST(FlightIntegration, LocationReuseConnection) { + ASSERT_OK(RunScenario("location:reuse_connection")); +} + TEST(FlightIntegration, SessionOptions) { ASSERT_OK(RunScenario("session_options")); } TEST(FlightIntegration, PollFlightInfo) { ASSERT_OK(RunScenario("poll_flight_info")); } diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index d4e0a2cda5bd8..6ba5d9c352da1 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -2079,6 +2079,50 @@ class FlightSqlExtensionScenario : public FlightSqlScenario { return Status::OK(); } }; + +/// \brief The server for testing arrow-flight-reuse-connection://. +class ReuseConnectionServer : public FlightServerBase { + public: + Status GetFlightInfo(const ServerCallContext& context, + const FlightDescriptor& descriptor, + std::unique_ptr* info) override { + auto location = Location::ReuseConnection(); + auto endpoint = FlightEndpoint{{"reuse"}, {location}}; + ARROW_ASSIGN_OR_RAISE(auto info_data, FlightInfo::Make(arrow::Schema({}), descriptor, + {endpoint}, -1, -1)); + *info = std::make_unique(std::move(info_data)); + return Status::OK(); + } +}; + +/// \brief A scenario for testing arrow-flight-reuse-connection://. +class ReuseConnectionScenario : public Scenario { + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + *server = std::make_unique(); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } + + Status RunClient(std::unique_ptr client) override { + auto descriptor = FlightDescriptor::Command("reuse"); + ARROW_ASSIGN_OR_RAISE(auto info, client->GetFlightInfo(descriptor)); + if (info->endpoints().size() != 1) { + return Status::Invalid("Expected 1 endpoint, got ", info->endpoints().size()); + } + const auto& endpoint = info->endpoints().front(); + if (endpoint.locations.size() != 1) { + return Status::Invalid("Expected 1 location, got ", + info->endpoints().front().locations.size()); + } else if (endpoint.locations.front().ToString() != + "arrow-flight-reuse-connection://?") { + return Status::Invalid("Expected arrow-flight-reuse-connection://?, got ", + endpoint.locations.front().ToString()); + } + return Status::OK(); + } +}; } // namespace Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) { @@ -2103,6 +2147,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr* } else if (scenario_name == "expiration_time:renew_flight_endpoint") { *out = std::make_shared(); return Status::OK(); + } else if (scenario_name == "location:reuse_connection") { + *out = std::make_shared(); + return Status::OK(); } else if (scenario_name == "session_options") { *out = std::make_shared(); return Status::OK(); diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 11b2baafad220..a1b799a3a069e 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -829,6 +829,12 @@ arrow::Result Location::Parse(const std::string& uri_string) { return location; } +const Location& Location::ReuseConnection() { + static Location kFallback = + Location::Parse("arrow-flight-reuse-connection://?").ValueOrDie(); + return kFallback; +} + arrow::Result Location::ForGrpcTcp(const std::string& host, const int port) { std::stringstream uri_string; uri_string << "grpc+tcp://" << host << ':' << port; diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 4b17149aa2d46..c96aa428b054e 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -424,6 +424,14 @@ struct ARROW_FLIGHT_EXPORT Location { /// \brief Initialize a location by parsing a URI string static arrow::Result Parse(const std::string& uri_string); + /// \brief Get the fallback URI. + /// + /// arrow-flight-reuse-connection:// means that a client may attempt to + /// reuse an existing connection to a Flight service to fetch data instead + /// of creating a new connection to one of the other locations listed in a + /// FlightEndpoint response. + static const Location& ReuseConnection(); + /// \brief Initialize a location for a non-TLS, gRPC-based Flight /// service from a host and port /// \param[in] host The hostname to connect to diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index e984468bc5052..7cdd5071b96c6 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -608,6 +608,11 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, "RenewFlightEndpoint are working as expected."), skip_testers={"JS", "C#", "Rust"}, ), + Scenario( + "location:reuse_connection", + description="Ensure arrow-flight-reuse-connection is accepted.", + skip_testers={"JS", "C#", "Rust"}, + ), Scenario( "session_options", description="Ensure Flight SQL Sessions work as expected.", diff --git a/docs/source/format/Flight.rst b/docs/source/format/Flight.rst index 73ca848b5e996..7ee84952b4350 100644 --- a/docs/source/format/Flight.rst +++ b/docs/source/format/Flight.rst @@ -121,6 +121,13 @@ A client that wishes to download the data would: connection to the original server to fetch data. Otherwise, the client must connect to one of the indicated locations. + The server may list "itself" as a location alongside other server + locations. Normally this requires the server to know its public + address, but it may also use the special URI string + ``arrow-flight-reuse-connection://?`` to tell clients that they may + reuse an existing connection to the same server, without having to + be able to name itself. See `Connection Reuse`_ below. + In this way, the locations inside an endpoint can also be thought of as performing look-aside load balancing or service discovery functions. And the endpoints can represent data that is partitioned @@ -307,29 +314,58 @@ well, in which case any `authentication method supported by gRPC .. _Mutual TLS (mTLS): https://grpc.io/docs/guides/auth/#supported-auth-mechanisms -Transport Implementations -========================= +Location URIs +============= Flight is primarily defined in terms of its Protobuf and gRPC specification below, but Arrow implementations may also support -alternative transports (see :ref:`status-flight-rpc`). In that case, -implementations should use the following URI schemes for the given -transport implementations: - -+----------------------------+----------------------------+ -| Transport | URI Scheme | -+============================+============================+ -| gRPC (plaintext) | grpc: or grpc+tcp: | -+----------------------------+----------------------------+ -| gRPC (TLS) | grpc+tls: | -+----------------------------+----------------------------+ -| gRPC (Unix domain socket) | grpc+unix: | -+----------------------------+----------------------------+ -| UCX_ (plaintext) | ucx: | -+----------------------------+----------------------------+ +alternative transports (see :ref:`status-flight-rpc`). Clients and +servers need to know which transport to use for a given URI in a +Location, so Flight implementations should use the following URI +schemes for the given transports: + ++----------------------------+--------------------------------+ +| Transport | URI Scheme | ++============================+================================+ +| gRPC (plaintext) | grpc: or grpc+tcp: | ++----------------------------+--------------------------------+ +| gRPC (TLS) | grpc+tls: | ++----------------------------+--------------------------------+ +| gRPC (Unix domain socket) | grpc+unix: | ++----------------------------+--------------------------------+ +| (reuse connection) | arrow-flight-reuse-connection: | ++----------------------------+--------------------------------+ +| UCX_ (plaintext) | ucx: | ++----------------------------+--------------------------------+ .. _UCX: https://openucx.org/ +Connection Reuse +---------------- + +"Reuse connection" above is not a particular transport. Instead, it +means that the client may try to execute DoGet against the same server +(and through the same connection) that it originally obtained the +FlightInfo from (i.e., that it called GetFlightInfo against). This is +interpreted the same way as when no specific ``Location`` are +returned. + +This allows the server to return "itself" as one possible location to +fetch data without having to know its own public address, which can be +useful in deployments where knowing this would be difficult or +impossible. For example, a developer may forward a remote service in +a cloud environment to their local machine; in this case, the remote +service would have no way to know the local hostname and port that it +is being accessed over. + +For compatibility reasons, the URI should always be +``arrow-flight-reuse-connection://?``, with the trailing empty query +string. Java's URI implementation does not accept ``scheme:`` or +``scheme://``, and C++'s implementation does not accept an empty +string, so the obvious candidates are not compatible. The chosen +representation can be parsed by both implementations, as well as Go's +``net/url`` and Python's ``urllib.parse``. + Error Handling ============== diff --git a/format/Flight.proto b/format/Flight.proto index 59714108e1cbc..4963e8c09ae47 100644 --- a/format/Flight.proto +++ b/format/Flight.proto @@ -369,7 +369,7 @@ message FlightInfo { /* * Application-defined metadata. - * + * * There is no inherent or required relationship between this * and the app_metadata fields in the FlightEndpoints or resulting * FlightData messages. Since this metadata is application-defined, @@ -440,11 +440,15 @@ message FlightEndpoint { * be redeemed on the current service where the ticket was * generated. * - * If the list is not empty, the expectation is that the ticket can - * be redeemed at any of the locations, and that the data returned - * will be equivalent. In this case, the ticket may only be redeemed - * at one of the given locations, and not (necessarily) on the - * current service. + * If the list is not empty, the expectation is that the ticket can be + * redeemed at any of the locations, and that the data returned will be + * equivalent. In this case, the ticket may only be redeemed at one of the + * given locations, and not (necessarily) on the current service. If one + * of the given locations is "arrow-flight-reuse-connection://?", the + * client may redeem the ticket on the service where the ticket was + * generated (i.e., the same as above), in addition to the other + * locations. (This URI was chosen to maximize compatibility, as 'scheme:' + * or 'scheme://' are not accepted by Java's java.net.URI.) * * In other words, an application can use multiple locations to * represent redundant and/or load balanced services. @@ -460,7 +464,7 @@ message FlightEndpoint { /* * Application-defined metadata. - * + * * There is no inherent or required relationship between this * and the app_metadata fields in the FlightInfo or resulting * FlightData messages. Since this metadata is application-defined, @@ -587,7 +591,7 @@ message SetSessionOptionsResult { message Error { ErrorValue value = 1; } - + map errors = 1; } diff --git a/go/arrow/flight/server.go b/go/arrow/flight/server.go index c70aceabcfe8e..9f76feda5c268 100644 --- a/go/arrow/flight/server.go +++ b/go/arrow/flight/server.go @@ -81,6 +81,14 @@ const ( CancelStatusNotCancellable = flight.CancelStatus_CANCEL_STATUS_NOT_CANCELLABLE ) +// Constants for Location +const ( + // LocationReuseConnection is a special location that tells clients + // they may fetch the data from the same service that they obtained + // the FlightEndpoint response from. + LocationReuseConnection = "arrow-flight-reuse-connection://?" +) + // RegisterFlightServiceServer registers an existing flight server onto an // existing grpc server, or anything that is a grpc service registrar. func RegisterFlightServiceServer(s *grpc.Server, srv FlightServer) { diff --git a/go/arrow/internal/flight_integration/scenario.go b/go/arrow/internal/flight_integration/scenario.go index 91658a694ecab..6c0d6d9048986 100644 --- a/go/arrow/internal/flight_integration/scenario.go +++ b/go/arrow/internal/flight_integration/scenario.go @@ -69,6 +69,8 @@ func GetScenario(name string, args ...string) Scenario { return &expirationTimeCancelFlightInfoScenarioTester{} case "expiration_time:renew_flight_endpoint": return &expirationTimeRenewFlightEndpointScenarioTester{} + case "location:reuse_connection": + return &locationReuseConnectionScenarioTester{} case "poll_flight_info": return &pollFlightInfoScenarioTester{} case "app_metadata_flight_info_endpoint": @@ -1136,6 +1138,56 @@ func (tester *expirationTimeRenewFlightEndpointScenarioTester) RunClient(addr st return nil } +type locationReuseConnectionScenarioTester struct { + flight.BaseFlightServer +} + +func (m *locationReuseConnectionScenarioTester) GetFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + return &flight.FlightInfo{ + Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{}, nil), memory.DefaultAllocator), + FlightDescriptor: desc, + Endpoint: []*flight.FlightEndpoint{{ + Ticket: &flight.Ticket{Ticket: []byte("reuse")}, + Location: []*flight.Location{{Uri: flight.LocationReuseConnection}}, + }}, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (tester *locationReuseConnectionScenarioTester) MakeServer(port int) flight.Server { + srv := flight.NewServerWithMiddleware(nil) + srv.RegisterFlightService(tester) + initServer(port, srv) + return srv +} + +func (tester *locationReuseConnectionScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error { + client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...) + if err != nil { + return err + } + defer client.Close() + + ctx := context.Background() + info, err := client.GetFlightInfo(ctx, &flight.FlightDescriptor{Type: flight.DescriptorCMD, Cmd: []byte("reuse")}) + if err != nil { + return err + } + + if len(info.Endpoint) != 1 { + return fmt.Errorf("Expected 1 endpoint, got %d", len(info.Endpoint)) + } + endpoint := info.Endpoint[0] + if len(endpoint.Location) != 1 { + return fmt.Errorf("Expected 1 location, got %d", len(endpoint.Location)) + } else if endpoint.Location[0].Uri != flight.LocationReuseConnection { + return fmt.Errorf("Expected %s, got %s", flight.LocationReuseConnection, endpoint.Location[0].Uri) + } + + return nil +} + type pollFlightInfoScenarioTester struct { flight.BaseFlightServer } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java index fe192aa0c3f9d..2eb3139c9dcdd 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java @@ -93,6 +93,19 @@ Flight.Location toProtocol() { return Flight.Location.newBuilder().setUri(uri.toString()).build(); } + /** + * Construct a special URI to indicate to clients that they may fetch data by reusing + * an existing connection to a Flight RPC server. + */ + public static Location reuseConnection() { + try { + return new Location(new URI(LocationSchemes.REUSE_CONNECTION, "", "", "", null)); + } catch (URISyntaxException e) { + // This should never happen. + throw new IllegalArgumentException(e); + } + } + /** * Construct a URI for a Flight+gRPC server without transport security. * diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java index 872e5b1c22deb..f1dbfb95f237e 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java @@ -25,6 +25,7 @@ public final class LocationSchemes { public static final String GRPC_INSECURE = "grpc+tcp"; public static final String GRPC_DOMAIN_SOCKET = "grpc+unix"; public static final String GRPC_TLS = "grpc+tls"; + public static final String REUSE_CONNECTION = "arrow-flight-reuse-connection"; private LocationSchemes() { throw new AssertionError("Do not instantiate this class."); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index ae520ee9b991b..bc34b5e6d6074 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -79,6 +79,12 @@ public void fastPathDefaults() { Assertions.assertFalse(ArrowMessage.ENABLE_ZERO_COPY_WRITE); } + @Test + public void fallbackLocation() { + Assertions.assertEquals("arrow-flight-reuse-connection://?", + Location.reuseConnection().getUri().toString()); + } + /** * ARROW-6017: we should be able to construct locations for unknown schemes. */ diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/LocationReuseConnectionScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/LocationReuseConnectionScenario.java new file mode 100644 index 0000000000000..7ffd09c5bf4e8 --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/LocationReuseConnectionScenario.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.types.pojo.Schema; + +/** Test the 'arrow-flight-reuse-connection' scheme. */ +public class LocationReuseConnectionScenario implements Scenario { + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { + return new ReuseConnectionProducer(); + } + + @Override + public void buildServer(FlightServer.Builder builder) throws Exception { + } + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) + throws Exception { + final FlightInfo info = client.getInfo(FlightDescriptor.command("reuse".getBytes(StandardCharsets.UTF_8))); + IntegrationAssertions.assertEquals(1, info.getEndpoints().size()); + IntegrationAssertions.assertEquals(1, info.getEndpoints().get(0).getLocations().size()); + Location actual = info.getEndpoints().get(0).getLocations().get(0); + IntegrationAssertions.assertEquals(Location.reuseConnection().getUri(), actual.getUri()); + } + + private static class ReuseConnectionProducer extends NoOpFlightProducer { + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + List endpoints = Collections.singletonList( + new FlightEndpoint(new Ticket(new byte[0]), Location.reuseConnection())); + return new FlightInfo( + new Schema(Collections.emptyList()), descriptor, endpoints, /*bytes*/ -1, /*records*/ -1); + } + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index 6878c22c5ccdc..5ce82e712ab77 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -44,6 +44,7 @@ private Scenarios() { scenarios.put("expiration_time:renew_flight_endpoint", ExpirationTimeRenewFlightEndpointScenario::new); scenarios.put("expiration_time:do_get", ExpirationTimeDoGetScenario::new); scenarios.put("expiration_time:list_actions", ExpirationTimeListActionsScenario::new); + scenarios.put("location:reuse_connection", LocationReuseConnectionScenario::new); scenarios.put("middleware", MiddlewareScenario::new); scenarios.put("ordered", OrderedScenario::new); scenarios.put("poll_flight_info", PollFlightInfoScenario::new); diff --git a/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java deleted file mode 100644 index dfb9a810857ba..0000000000000 --- a/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.arrow.flight.integration.tests; - -import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.flight.FlightServer; -import org.apache.arrow.flight.Location; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.junit.jupiter.api.Test; - -/** - * Run the integration test scenarios in-process. - */ -class IntegrationTest { - @Test - void authBasicProto() throws Exception { - testScenario("auth:basic_proto"); - } - - @Test - void middleware() throws Exception { - testScenario("middleware"); - } - - @Test - void flightSql() throws Exception { - testScenario("flight_sql"); - } - - void testScenario(String scenarioName) throws Exception { - try (final BufferAllocator allocator = new RootAllocator()) { - final FlightServer.Builder builder = FlightServer.builder() - .allocator(allocator) - .location(Location.forGrpcInsecure("0.0.0.0", 0)); - final Scenario scenario = Scenarios.getScenario(scenarioName); - scenario.buildServer(builder); - builder.producer(scenario.producer(allocator, Location.forGrpcInsecure("0.0.0.0", 0))); - - try (final FlightServer server = builder.build()) { - server.start(); - - final Location location = Location.forGrpcInsecure("localhost", server.getPort()); - try (final FlightClient client = FlightClient.builder(allocator, location).build()) { - scenario.client(allocator, location, client); - } - } - } - } -} diff --git a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java index f814427567ae9..40aceb336e0ea 100644 --- a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java +++ b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java @@ -53,6 +53,11 @@ void expirationTimeRenewFlightEndpoint() throws Exception { testScenario("expiration_time:renew_flight_endpoint"); } + @Test + void locationReuseConnection() throws Exception { + testScenario("location:reuse_connection"); + } + @Test void middleware() throws Exception { testScenario("middleware"); diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index 234820bd41823..a47cffc2fcf09 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -124,6 +124,13 @@ public List getStreams(final FlightInfo flightInfo) // It would also be good to identify when the reported location is the same as the original connection's // Location and skip creating a FlightClient in that scenario. final URI endpointUri = endpoint.getLocations().get(0).getUri(); + + if (endpointUri.getScheme().equals(LocationSchemes.REUSE_CONNECTION)) { + endpoints.add(new CloseableEndpointStreamPair( + sqlClient.getStream(endpoint.getTicket(), getOptions()), null)); + continue; + } + final Builder builderForEndpoint = new Builder(ArrowFlightSqlClientHandler.this.builder) .withHost(endpointUri.getHost()) .withPort(endpointUri.getPort()) diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java index 0e3e015a04636..ad01e8767b793 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java @@ -39,6 +39,7 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Random; @@ -46,6 +47,7 @@ import java.util.concurrent.CountDownLatch; import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.apache.arrow.driver.jdbc.utils.FallbackFlightSqlProducer; import org.apache.arrow.driver.jdbc.utils.PartitionedFlightSqlProducer; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightProducer; @@ -55,6 +57,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -63,6 +66,7 @@ import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; +import org.junit.jupiter.api.Assertions; import org.junit.rules.ErrorCollector; import com.google.common.collect.ImmutableSet; @@ -455,6 +459,69 @@ allocator, forGrpcInsecure("localhost", 0), rootProducer) } } + @Test + public void testFallbackFlightServer() throws Exception { + final Schema schema = new Schema( + Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType()))); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) { + resultData.setRowCount(1); + ((IntVector) resultData.getVector(0)).set(0, 1); + + try (final FallbackFlightSqlProducer rootProducer = new FallbackFlightSqlProducer(resultData); + FlightServer rootServer = FlightServer.builder( + allocator, forGrpcInsecure("localhost", 0), rootProducer) + .build() + .start(); + Connection newConnection = DriverManager.getConnection(String.format( + "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false", + rootServer.getLocation().getUri().getHost(), rootServer.getPort())); + Statement newStatement = newConnection.createStatement(); + ResultSet result = newStatement.executeQuery("fallback")) { + List actualData = new ArrayList<>(); + while (result.next()) { + actualData.add(result.getInt(1)); + } + + // Assert + assertEquals(resultData.getRowCount(), actualData.size()); + assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0))); + } + } + } + + @Test + public void testFallbackSecondFlightServer() throws Exception { + final Schema schema = new Schema( + Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType()))); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) { + resultData.setRowCount(1); + ((IntVector) resultData.getVector(0)).set(0, 1); + + try (final FallbackFlightSqlProducer rootProducer = new FallbackFlightSqlProducer(resultData); + FlightServer rootServer = FlightServer.builder( + allocator, forGrpcInsecure("localhost", 0), rootProducer) + .build() + .start(); + Connection newConnection = DriverManager.getConnection(String.format( + "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false", + rootServer.getLocation().getUri().getHost(), rootServer.getPort())); + Statement newStatement = newConnection.createStatement()) { + + // TODO(https://github.com/apache/arrow/issues/38573) + // XXX: we could try to assert more structure but then we'd have to hardcode + // a particular exception chain which may be fragile + Assertions.assertThrows(SQLException.class, () -> { + try (ResultSet result = newStatement.executeQuery("fallback with error")) { + // Empty body + } + }); + + } + } + } + @Test public void testShouldRunSelectQueryWithEmptyVectorsEmbedded() throws Exception { try (Statement statement = connection.createStatement(); diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java new file mode 100644 index 0000000000000..2257220a4c845 --- /dev/null +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.util.Collections; +import java.util.List; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.BasicFlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; + +public class FallbackFlightSqlProducer extends BasicFlightSqlProducer { + private final VectorSchemaRoot data; + + public FallbackFlightSqlProducer(VectorSchemaRoot resultData) { + this.data = resultData; + } + + @Override + protected List determineEndpoints( + T request, FlightDescriptor flightDescriptor, Schema schema) { + return Collections.emptyList(); + } + + @Override + public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request, + CallContext context, StreamListener listener) { + final FlightSql.ActionCreatePreparedStatementResult.Builder resultBuilder = + FlightSql.ActionCreatePreparedStatementResult.newBuilder() + .setPreparedStatementHandle(request.getQueryBytes()); + + final ByteString datasetSchemaBytes = ByteString.copyFrom(data.getSchema().serializeAsMessage()); + + resultBuilder.setDatasetSchema(datasetSchemaBytes); + listener.onNext(new Result(Any.pack(resultBuilder.build()).toByteArray())); + listener.onCompleted(); + } + + @Override + public FlightInfo getFlightInfoStatement( + FlightSql.CommandStatementQuery command, CallContext context, FlightDescriptor descriptor) { + return getFlightInfo(descriptor, command.getQuery()); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + return getFlightInfo(descriptor, command.getPreparedStatementHandle().toStringUtf8()); + } + + @Override + public void getStreamStatement(FlightSql.TicketStatementQuery ticket, CallContext context, + ServerStreamListener listener) { + listener.start(data); + listener.putNext(); + listener.completed(); + } + + @Override + public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request, + CallContext context, StreamListener listener) { + listener.onCompleted(); + } + + private FlightInfo getFlightInfo(FlightDescriptor descriptor, String query) { + final List endpoints; + final Ticket ticket = new Ticket( + Any.pack(FlightSql.TicketStatementQuery.getDefaultInstance()).toByteArray()); + if (query.equals("fallback")) { + endpoints = Collections.singletonList(FlightEndpoint.builder(ticket, Location.reuseConnection()).build()); + } else if (query.equals("fallback with error")) { + endpoints = Collections.singletonList( + FlightEndpoint.builder(ticket, + Location.forGrpcInsecure("localhost", 9999), + Location.reuseConnection()).build()); + } else { + throw CallStatus.UNIMPLEMENTED.withDescription(query).toRuntimeException(); + } + return FlightInfo.builder(data.getSchema(), descriptor, endpoints).build(); + } +}