Skip to content

Commit

Permalink
Add ability to override SSL hostname checking
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Jun 20, 2019
1 parent 961927a commit 581fc75
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 11 deletions.
5 changes: 5 additions & 0 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ class FlightClient::FlightClientImpl {
args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, 100);
// Receive messages of any size
args.SetMaxReceiveMessageSize(-1);

if (options.override_hostname != "") {
args.SetSslTargetNameOverride(options.override_hostname);
}

stub_ = pb::FlightService::NewStub(
grpc::CreateCustomChannel(grpc_uri.str(), creds, args));
return Status::OK();
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions {

class ARROW_FLIGHT_EXPORT FlightClientOptions {
public:
/// \brief Root certificates to use for validating server
/// certificates.
std::string tls_root_certs;
/// \brief Override the hostname checked by TLS. Use with caution.
std::string override_hostname;
};

/// \brief Client class for Arrow Flight RPC services (gRPC-based).
Expand Down
18 changes: 18 additions & 0 deletions cpp/src/arrow/flight/flight-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -675,5 +675,23 @@ TEST_F(TestTls, DoAction) {
ASSERT_EQ(result->body->ToString(), "Hello, world!");
}

TEST_F(TestTls, OverrideHostname) {
std::unique_ptr<FlightClient> client;
auto client_options = FlightClientOptions();
client_options.override_hostname = "fakehostname";
CertKeyPair root_cert;
ASSERT_OK(ExampleTlsCertificateRoot(&root_cert));
client_options.tls_root_certs = root_cert.pem_cert;
ASSERT_OK(FlightClient::Connect(server_->location(), client_options, &client));

FlightCallOptions options;
options.timeout = TimeoutDuration{5.0};
Action action;
action.type = "test";
action.body = Buffer::FromString("");
std::unique_ptr<ResultStream> results;
ASSERT_RAISES(IOError, client->DoAction(options, action, &results));
}

} // namespace flight
} // namespace arrow
15 changes: 9 additions & 6 deletions java/flight/pom.xml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
<?xml version="1.0"?>
<!-- Copyright (C) 2017-2018 Dremio Corporation Licensed under the Apache
License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
OR CONDITIONS OF ANY KIND, either express or implied. See the License for
<!-- Copyright (C) 2017-2018 Dremio Corporation Licensed under the Apache
License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
OR CONDITIONS OF ANY KIND, either express or implied. See the License for
the specific language governing permissions and limitations under the License. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
Expand Down Expand Up @@ -137,6 +137,9 @@
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<enableAssertions>false</enableAssertions>
<systemPropertyVariables>
<arrow.test.dataRoot>${project.basedir}/../../testing/data</arrow.test.dataRoot>
</systemPropertyVariables>
</configuration>
</plugin>
<plugin>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ public static final class Builder {
private InputStream trustedCertificates = null;
private InputStream clientCertificate = null;
private InputStream clientKey = null;
private String overrideHostname = null;

private Builder() {
}
Expand All @@ -371,6 +372,12 @@ public Builder useTls() {
return this;
}

/** Override the hostname checked for TLS. Use with caution in production. */
public Builder overrideHostname(final String hostname) {
this.overrideHostname = hostname;
return this;
}

/** Set the maximum inbound message size. */
public Builder maxInboundMessageSize(int maxSize) {
Preconditions.checkArgument(maxSize > 0);
Expand Down Expand Up @@ -461,6 +468,10 @@ public FlightClient build() {
throw new RuntimeException(e);
}
}

if (this.overrideHostname != null) {
builder.overrideAuthority(this.overrideHostname);
}
} else {
builder.usePlaintext();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,23 @@ public void awaitTermination() throws InterruptedException {
server.awaitTermination();
}

/** Request that the server shut down. */
public void shutdown() {
server.shutdown();
}

/**
* Wait for the server to shut down with a timeout.
* @return true if the server shut down successfully.
*/
public boolean awaitTermination(final long timeout, final TimeUnit unit) throws InterruptedException {
return server.awaitTermination(timeout, unit);
}

/** Shutdown the server, waits for up to 6 seconds for successful shutdown before returning. */
public void close() throws InterruptedException {
server.shutdown();
final boolean terminated = server.awaitTermination(3000, TimeUnit.MILLISECONDS);
shutdown();
final boolean terminated = awaitTermination(3000, TimeUnit.MILLISECONDS);
if (terminated) {
logger.debug("Server was terminated within 3s");
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,27 @@

package org.apache.arrow.flight;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.function.Function;

/**
* Utility methods and constants for testing flight servers.
*/
public class FlightTestUtil {

private static final Random RANDOM = new Random();

public static final String LOCALHOST = "localhost";
public static final String TEST_DATA_ENV_VAR = "ARROW_TEST_DATA";
public static final String TEST_DATA_PROPERTY = "arrow.test.dataRoot";

/**
* Returns a a FlightServer (actually anything that is startable)
Expand Down Expand Up @@ -62,6 +71,30 @@ public static <T> T getStartedServer(Function<Integer, T> newServerFromPort) thr
return server;
}

static Path getTestDataRoot() {
String path = System.getenv(TEST_DATA_ENV_VAR);
if (path == null) {
path = System.getProperty(TEST_DATA_PROPERTY);
}
return Paths.get(Objects.requireNonNull(path,
String.format("Could not find test data path. Set the environment variable %s or the JVM property %s.",
TEST_DATA_ENV_VAR, TEST_DATA_PROPERTY)));
}

static Path getFlightTestDataRoot() {
return getTestDataRoot().resolve("flight");
}

static Path exampleTlsRootCert() {
return getFlightTestDataRoot().resolve("root-ca.pem");
}

static List<CertKeyPair> exampleTlsCerts() {
final Path root = getFlightTestDataRoot();
return Arrays.asList(new CertKeyPair(root.resolve("cert0.pem").toFile(), root.resolve("cert0.pkcs1").toFile()),
new CertKeyPair(root.resolve("cert1.pem").toFile(), root.resolve("cert1.pkcs1").toFile()));
}

static boolean isEpollAvailable() {
try {
Class<?> epoll = Class.forName("io.netty.channel.epoll.Epoll");
Expand All @@ -84,6 +117,17 @@ static boolean isNativeTransportAvailable() {
return isEpollAvailable() || isKqueueAvailable();
}

public static class CertKeyPair {

public final File cert;
public final File key;

public CertKeyPair(File cert, File key) {
this.cert = cert;
this.key = key;
}
}

private FlightTestUtil() {
}
}
130 changes: 130 additions & 0 deletions java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.flight;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.function.Consumer;

import org.apache.arrow.flight.FlightClient.Builder;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;

import org.junit.Assert;
import org.junit.Test;

/**
* Tests for TLS in Flight.
*/
public class TestTls {

/**
* Test a basic request over TLS.
*/
@Test
public void connectTls() {
test((builder) -> {
try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
final FlightClient client = builder.trustedCertificates(roots).build()) {
final Iterator<Result> responses = client.doAction(new Action("hello-world"));
final byte[] response = responses.next().getBody();
Assert.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8));
Assert.assertFalse(responses.hasNext());
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}
});
}

/**
* Make sure that connections are rejected when the root certificate isn't trusted.
*/
@Test(expected = io.grpc.StatusRuntimeException.class)
public void rejectInvalidCert() {
test((builder) -> {
try (final FlightClient client = builder.build()) {
final Iterator<Result> responses = client.doAction(new Action("hello-world"));
responses.next().getBody();
Assert.fail("Call should have failed");
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
});
}

/**
* Make sure that connections are rejected when the hostname doesn't match.
*/
@Test(expected = io.grpc.StatusRuntimeException.class)
public void rejectHostname() {
test((builder) -> {
try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
final FlightClient client = builder.trustedCertificates(roots).overrideHostname("fakehostname")
.build()) {
final Iterator<Result> responses = client.doAction(new Action("hello-world"));
responses.next().getBody();
Assert.fail("Call should have failed");
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}
});
}


void test(Consumer<Builder> testFn) {
final FlightTestUtil.CertKeyPair certKey = FlightTestUtil.exampleTlsCerts().get(0);
try (
BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
Producer producer = new Producer();
FlightServer s =
FlightTestUtil.getStartedServer(
(port) -> {
try {
return FlightServer.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, port), producer)
.useTls(certKey.cert, certKey.key)
.build();
} catch (IOException e) {
throw new RuntimeException(e);
}
})) {
final Builder builder = FlightClient.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, s.getPort()));
testFn.accept(builder);
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}
}

static class Producer extends NoOpFlightProducer implements AutoCloseable {

@Override
public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
if (action.getType().equals("hello-world")) {
listener.onNext(new Result("Hello, world!".getBytes(StandardCharsets.UTF_8)));
listener.onCompleted();
}
listener.onError(new UnsupportedOperationException("Invalid action " + action.getType()));
}

@Override
public void close() {
}
}
}
8 changes: 6 additions & 2 deletions python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ cdef class FlightClient:
.format(self.__class__.__name__))

@staticmethod
def connect(location, tls_root_certs=None):
def connect(location, tls_root_certs=None, override_hostname=None):
"""
Connect to a Flight service on the given host and port.
Expand All @@ -428,8 +428,10 @@ cdef class FlightClient:
location : Location
location to connect to
tls_root_certs : bytes
tls_root_certs : bytes or None
PEM-encoded
unsafe_override_hostname : str or None
Override the hostname checked by TLS. Insecure, use with caution.
"""
cdef:
FlightClient result = FlightClient.__new__(FlightClient)
Expand All @@ -439,6 +441,8 @@ cdef class FlightClient:

if tls_root_certs:
c_options.tls_root_certs = tobytes(tls_root_certs)
if override_hostname:
c_options.override_hostname = tobytes(override_hostname)

with nogil:
check_status(CFlightClient.Connect(c_location, c_options,
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow_flight.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions":
CFlightClientOptions()
c_string tls_root_certs
c_string override_hostname

cdef cppclass CFlightClient" arrow::flight::FlightClient":
@staticmethod
Expand Down
15 changes: 15 additions & 0 deletions python/pyarrow/tests/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,18 @@ def test_tls_do_get():
server_location, tls_root_certs=certs["root_cert"])
data = client.do_get(flight.Ticket(b'ints')).read_all()
assert data.equals(table)


def test_tls_override_hostname():
"""Check that incorrectly overriding the hostname fails."""
certs = example_tls_certs()

with flight_server(
ConstantFlightServer, tls_certificates=certs["certificates"],
connect_args=dict(tls_root_certs=certs["root_cert"]),
) as server_location:
client = flight.FlightClient.connect(
server_location, tls_root_certs=certs["root_cert"],
override_hostname="fakehostname")
with pytest.raises(pa.ArrowIOError):
client.do_get(flight.Ticket(b'ints'))
Loading

0 comments on commit 581fc75

Please sign in to comment.