Skip to content
This repository has been archived by the owner on Nov 14, 2024. It is now read-only.

Better logging for endpoint verification #6314

Merged
merged 7 commits into from
Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.LoadingCache;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;
import com.palantir.atlasdb.cassandra.CassandraCredentialsConfig;
import com.palantir.atlasdb.cassandra.CassandraKeyValueServiceConfig;
import com.palantir.atlasdb.keyvalue.cassandra.ImmutableCassandraClientConfig.SocketTimeoutMillisBuildStage;
Expand All @@ -44,7 +45,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import java.util.Set;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import org.apache.cassandra.thrift.AuthenticationRequest;
Expand Down Expand Up @@ -179,8 +180,8 @@ private static Cassandra.Client getRawClient(
try {
SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
thriftSocket.getSocket(), addr.getHostString(), addr.getPort(), true);
verifyEndpoint(cassandraServer, socket, clientConfig.enableEndpointVerification());
thriftSocket = tSocketFactory.create(socket);
verifyEndpoint(cassandraServer, socket, clientConfig.enableEndpointVerification());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noticed this "bug" when revisiting things -- our initial reach-out will not include the thrift options we later set on the socket. It's probably OK for us to ignore this, but from a code "safety" point-of-view we should really just set the options for consistency reasons.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this makes sense, would recommend updating the PR description :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes, I should do that!

success = true;
} catch (IOException e) {
throw new TTransportException(e);
Expand Down Expand Up @@ -222,23 +223,37 @@ private static void login(Client client, CassandraCredentialsConfig config) thro
* This will check both ip address/hostname, and uses the IP address associated with the socket, rather
* that what has been provided. Hostname/ip address are both need to be checked, as historically we've
* connected to Cassandra directly using IP addresses, and therefore need to support such cases.
*
* Will only throw when throwOnFailure is true, even if the socket is closed during verification.
*/
@VisibleForTesting
static void verifyEndpoint(CassandraServer cassandraServer, SSLSocket socket, boolean throwOnFailure)
throws SafeSSLPeerUnverifiedException {
boolean endpointVerified = Stream.of(
socket.getInetAddress().getHostAddress(), cassandraServer.cassandraHostName())
.anyMatch(address -> hostnameVerifier.verify(address, socket.getSession()));

Set<String> endpointsToCheck = getEndpointsToCheck(cassandraServer, socket);
boolean endpointVerified =
endpointsToCheck.stream().anyMatch(address -> hostnameVerifier.verify(address, socket.getSession()));
if (socket.isClosed()) {
if (throwOnFailure) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah nice catch!

throw new SafeSSLPeerUnverifiedException(
"Unable to verify endpoints as socket is closed.",
SafeArg.of("endpoint", socket.getInetAddress()));
}
return;
}
if (!endpointVerified) {
log.warn("Endpoint verification failed for host.", SafeArg.of("cassandraServer", cassandraServer));
log.warn("Endpoint verification failed for host.", SafeArg.of("endpointsChecked", endpointsToCheck));
if (throwOnFailure) {
throw new SafeSSLPeerUnverifiedException(
"Endpoint verification failed for host.", SafeArg.of("cassandraServer", cassandraServer));
"Endpoint verification failed for host.", SafeArg.of("endpointsChecked", endpointsToCheck));
}
}
}

@VisibleForTesting
static Set<String> getEndpointsToCheck(CassandraServer cassandraServer, SSLSocket socket) {
return ImmutableSet.of(socket.getInetAddress().getHostAddress(), cassandraServer.cassandraHostName());
}

@Override
public boolean validateObject(PooledObject<CassandraClient> client) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,40 @@ public void verifyEndpointDoesNotThrowWhenHostnameNotPresentButIpIs() {
.doesNotThrowAnyException();
}

@Test
public void verifyEndpointThrowsWhenSocketIsClosed() {
SSLSocket sslSocket = createSSLSocket(DEFAULT_SERVER, DEFAULT_ADDRESS);
when(sslSocket.isClosed()).thenReturn(true);
assertThatThrownBy(() -> CassandraClientFactory.verifyEndpoint(DEFAULT_SERVER, sslSocket, true))
.isInstanceOf(SafeSSLPeerUnverifiedException.class);
}

@Test
public void verifyEndpointDoesNotThrowWhenSocketIsClosedAndThrowOnFailureIsFalse() {
SSLSocket sslSocket = createSSLSocket(DEFAULT_SERVER, DEFAULT_ADDRESS);
when(sslSocket.isClosed()).thenReturn(true);
assertThatCode(() -> CassandraClientFactory.verifyEndpoint(DEFAULT_SERVER, sslSocket, false))
.doesNotThrowAnyException();
}

@Test
public void getEndpointsToCheckDeduplicatesMatchingHostnameIp() {
CassandraServer cassandraServer =
CassandraServer.of(InetSocketAddress.createUnresolved(DEFAULT_ADDRESS.getHostAddress(), 4000));
SSLSocket sslSocket = createSSLSocket(cassandraServer, DEFAULT_ADDRESS);
assertThat(CassandraClientFactory.getEndpointsToCheck(cassandraServer, sslSocket))
.isNotEmpty()
.containsExactly(DEFAULT_ADDRESS.getHostAddress());
}

@Test
public void getEndpointsToCheckPerformsNoDeduplicationWhenHostnameIpDiffer() {
SSLSocket sslSocket = createSSLSocket(DEFAULT_SERVER, DEFAULT_ADDRESS);
assertThat(CassandraClientFactory.getEndpointsToCheck(DEFAULT_SERVER, sslSocket))
.isNotEmpty()
.containsExactlyInAnyOrder(DEFAULT_SERVER.cassandraHostName(), DEFAULT_ADDRESS.getHostAddress());
}

@SuppressWarnings("ReverseDnsLookup")
private static InetSocketAddress mockInetSocketAddress(String ipAddress) {
InetAddress inetAddress = mockInetAddress(ipAddress);
Expand Down