diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 2b7c69919763e..1926928c6438e 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -259,6 +259,11 @@ class FlightClient::FlightClientImpl { args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, 100); // Receive messages of any size args.SetMaxReceiveMessageSize(-1); + + if (options.override_hostname != "") { + args.SetSslTargetNameOverride(options.override_hostname); + } + stub_ = pb::FlightService::NewStub( grpc::CreateCustomChannel(grpc_uri.str(), creds, args)); return Status::OK(); diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 689c9f8c5b501..b8a5d4f4b91dd 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -59,7 +59,11 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions { class ARROW_FLIGHT_EXPORT FlightClientOptions { public: + /// \brief Root certificates to use for validating server + /// certificates. std::string tls_root_certs; + /// \brief Override the hostname checked by TLS. Use with caution. + std::string override_hostname; }; /// \brief Client class for Arrow Flight RPC services (gRPC-based). diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc index b295878641523..3c0b67cd9927b 100644 --- a/cpp/src/arrow/flight/flight-test.cc +++ b/cpp/src/arrow/flight/flight-test.cc @@ -675,5 +675,23 @@ TEST_F(TestTls, DoAction) { ASSERT_EQ(result->body->ToString(), "Hello, world!"); } +TEST_F(TestTls, OverrideHostname) { + std::unique_ptr client; + auto client_options = FlightClientOptions(); + client_options.override_hostname = "fakehostname"; + CertKeyPair root_cert; + ASSERT_OK(ExampleTlsCertificateRoot(&root_cert)); + client_options.tls_root_certs = root_cert.pem_cert; + ASSERT_OK(FlightClient::Connect(server_->location(), client_options, &client)); + + FlightCallOptions options; + options.timeout = TimeoutDuration{5.0}; + Action action; + action.type = "test"; + action.body = Buffer::FromString(""); + std::unique_ptr results; + ASSERT_RAISES(IOError, client->DoAction(options, action, &results)); +} + } // namespace flight } // namespace arrow diff --git a/java/flight/pom.xml b/java/flight/pom.xml index 7d01a6e118e23..b03fbe6a88194 100644 --- a/java/flight/pom.xml +++ b/java/flight/pom.xml @@ -1,10 +1,10 @@ - 4.0.0 @@ -137,6 +137,9 @@ maven-surefire-plugin false + + ${project.basedir}/../../testing/data + diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index c70e1fddca22a..37e4514db7d82 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -354,6 +354,7 @@ public static final class Builder { private InputStream trustedCertificates = null; private InputStream clientCertificate = null; private InputStream clientKey = null; + private String overrideHostname = null; private Builder() { } @@ -371,6 +372,12 @@ public Builder useTls() { return this; } + /** Override the hostname checked for TLS. Use with caution in production. */ + public Builder overrideHostname(final String hostname) { + this.overrideHostname = hostname; + return this; + } + /** Set the maximum inbound message size. */ public Builder maxInboundMessageSize(int maxSize) { Preconditions.checkArgument(maxSize > 0); @@ -461,6 +468,10 @@ public FlightClient build() { throw new RuntimeException(e); } } + + if (this.overrideHostname != null) { + builder.overrideAuthority(this.overrideHostname); + } } else { builder.usePlaintext(); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java index eaea0441b2c56..cd59a75cbbb97 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -72,10 +72,23 @@ public void awaitTermination() throws InterruptedException { server.awaitTermination(); } + /** Request that the server shut down. */ + public void shutdown() { + server.shutdown(); + } + + /** + * Wait for the server to shut down with a timeout. + * @return true if the server shut down successfully. + */ + public boolean awaitTermination(final long timeout, final TimeUnit unit) throws InterruptedException { + return server.awaitTermination(timeout, unit); + } + /** Shutdown the server, waits for up to 6 seconds for successful shutdown before returning. */ public void close() throws InterruptedException { - server.shutdown(); - final boolean terminated = server.awaitTermination(3000, TimeUnit.MILLISECONDS); + shutdown(); + final boolean terminated = awaitTermination(3000, TimeUnit.MILLISECONDS); if (terminated) { logger.debug("Server was terminated within 3s"); return; diff --git a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java index f6b9e86780729..3cb09ef5cd9f8 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java @@ -17,8 +17,14 @@ package org.apache.arrow.flight; +import java.io.File; import java.io.IOException; import java.lang.reflect.InvocationTargetException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; import java.util.Random; import java.util.function.Function; @@ -26,9 +32,12 @@ * Utility methods and constants for testing flight servers. */ public class FlightTestUtil { + private static final Random RANDOM = new Random(); public static final String LOCALHOST = "localhost"; + public static final String TEST_DATA_ENV_VAR = "ARROW_TEST_DATA"; + public static final String TEST_DATA_PROPERTY = "arrow.test.dataRoot"; /** * Returns a a FlightServer (actually anything that is startable) @@ -62,6 +71,30 @@ public static T getStartedServer(Function newServerFromPort) thr return server; } + static Path getTestDataRoot() { + String path = System.getenv(TEST_DATA_ENV_VAR); + if (path == null) { + path = System.getProperty(TEST_DATA_PROPERTY); + } + return Paths.get(Objects.requireNonNull(path, + String.format("Could not find test data path. Set the environment variable %s or the JVM property %s.", + TEST_DATA_ENV_VAR, TEST_DATA_PROPERTY))); + } + + static Path getFlightTestDataRoot() { + return getTestDataRoot().resolve("flight"); + } + + static Path exampleTlsRootCert() { + return getFlightTestDataRoot().resolve("root-ca.pem"); + } + + static List exampleTlsCerts() { + final Path root = getFlightTestDataRoot(); + return Arrays.asList(new CertKeyPair(root.resolve("cert0.pem").toFile(), root.resolve("cert0.pkcs1").toFile()), + new CertKeyPair(root.resolve("cert1.pem").toFile(), root.resolve("cert1.pkcs1").toFile())); + } + static boolean isEpollAvailable() { try { Class epoll = Class.forName("io.netty.channel.epoll.Epoll"); @@ -84,6 +117,17 @@ static boolean isNativeTransportAvailable() { return isEpollAvailable() || isKqueueAvailable(); } + public static class CertKeyPair { + + public final File cert; + public final File key; + + public CertKeyPair(File cert, File key) { + this.cert = cert; + this.key = key; + } + } + private FlightTestUtil() { } } diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java new file mode 100644 index 0000000000000..c22304d56470f --- /dev/null +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.function.Consumer; + +import org.apache.arrow.flight.FlightClient.Builder; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests for TLS in Flight. + */ +public class TestTls { + + /** + * Test a basic request over TLS. + */ + @Test + public void connectTls() { + test((builder) -> { + try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile()); + final FlightClient client = builder.trustedCertificates(roots).build()) { + final Iterator responses = client.doAction(new Action("hello-world")); + final byte[] response = responses.next().getBody(); + Assert.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8)); + Assert.assertFalse(responses.hasNext()); + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Make sure that connections are rejected when the root certificate isn't trusted. + */ + @Test(expected = io.grpc.StatusRuntimeException.class) + public void rejectInvalidCert() { + test((builder) -> { + try (final FlightClient client = builder.build()) { + final Iterator responses = client.doAction(new Action("hello-world")); + responses.next().getBody(); + Assert.fail("Call should have failed"); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Make sure that connections are rejected when the hostname doesn't match. + */ + @Test(expected = io.grpc.StatusRuntimeException.class) + public void rejectHostname() { + test((builder) -> { + try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile()); + final FlightClient client = builder.trustedCertificates(roots).overrideHostname("fakehostname") + .build()) { + final Iterator responses = client.doAction(new Action("hello-world")); + responses.next().getBody(); + Assert.fail("Call should have failed"); + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + }); + } + + + void test(Consumer testFn) { + final FlightTestUtil.CertKeyPair certKey = FlightTestUtil.exampleTlsCerts().get(0); + try ( + BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + Producer producer = new Producer(); + FlightServer s = + FlightTestUtil.getStartedServer( + (port) -> { + try { + return FlightServer.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, port), producer) + .useTls(certKey.cert, certKey.key) + .build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + })) { + final Builder builder = FlightClient.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, s.getPort())); + testFn.accept(builder); + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + } + + static class Producer extends NoOpFlightProducer implements AutoCloseable { + + @Override + public void doAction(CallContext context, Action action, StreamListener listener) { + if (action.getType().equals("hello-world")) { + listener.onNext(new Result("Hello, world!".getBytes(StandardCharsets.UTF_8))); + listener.onCompleted(); + } + listener.onError(new UnsupportedOperationException("Invalid action " + action.getType())); + } + + @Override + public void close() { + } + } +} diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index c916e6bcf56ca..7ca83a94994c7 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -419,7 +419,7 @@ cdef class FlightClient: .format(self.__class__.__name__)) @staticmethod - def connect(location, tls_root_certs=None): + def connect(location, tls_root_certs=None, override_hostname=None): """ Connect to a Flight service on the given host and port. @@ -428,8 +428,10 @@ cdef class FlightClient: location : Location location to connect to - tls_root_certs : bytes + tls_root_certs : bytes or None PEM-encoded + unsafe_override_hostname : str or None + Override the hostname checked by TLS. Insecure, use with caution. """ cdef: FlightClient result = FlightClient.__new__(FlightClient) @@ -439,6 +441,8 @@ cdef class FlightClient: if tls_root_certs: c_options.tls_root_certs = tobytes(tls_root_certs) + if override_hostname: + c_options.override_hostname = tobytes(override_hostname) with nogil: check_status(CFlightClient.Connect(c_location, c_options, diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 14d1ed163d186..61e9571995d4f 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -170,6 +170,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions": CFlightClientOptions() c_string tls_root_certs + c_string override_hostname cdef cppclass CFlightClient" arrow::flight::FlightClient": @staticmethod diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index f4c9cc12beee9..3088a7a86f110 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -570,3 +570,18 @@ def test_tls_do_get(): server_location, tls_root_certs=certs["root_cert"]) data = client.do_get(flight.Ticket(b'ints')).read_all() assert data.equals(table) + + +def test_tls_override_hostname(): + """Check that incorrectly overriding the hostname fails.""" + certs = example_tls_certs() + + with flight_server( + ConstantFlightServer, tls_certificates=certs["certificates"], + connect_args=dict(tls_root_certs=certs["root_cert"]), + ) as server_location: + client = flight.FlightClient.connect( + server_location, tls_root_certs=certs["root_cert"], + override_hostname="fakehostname") + with pytest.raises(pa.ArrowIOError): + client.do_get(flight.Ticket(b'ints')) diff --git a/testing b/testing index 12f9dbd2a37ee..a674dac190c5f 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 12f9dbd2a37eea6fa370e108a1d797ee1167724a +Subproject commit a674dac190c5fc626964c9b611c67552fa2e530d