Skip to content

Commit

Permalink
Polishing #1209
Browse files Browse the repository at this point in the history
Simplify SSL tests. Refactor SSLEngine initialization into its own method.
  • Loading branch information
mp911de committed Jan 15, 2020
1 parent e4e4303 commit 120d57d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 92 deletions.
66 changes: 37 additions & 29 deletions src/main/java/io/lettuce/core/SslConnectionBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import io.lettuce.core.internal.LettuceAssert;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.resource.ClientResources;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler;
Expand Down Expand Up @@ -149,35 +150,7 @@ public SslChannelInitializer(Supplier<AsyncCommand<?, ?, ?>> pingCommandSupplier
@Override
protected void initChannel(Channel channel) throws Exception {

SSLParameters sslParams = new SSLParameters();

SslContextBuilder sslContextBuilder = SslContextBuilder.forClient().sslProvider(sslOptions.getSslProvider());
if (verifyPeer) {
sslParams.setEndpointIdentificationAlgorithm("HTTPS");
} else {
sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
}

if (sslOptions.getKeystore() != null) {
try (InputStream is = sslOptions.getKeystore().openStream()) {
sslContextBuilder.keyManager(createKeyManagerFactory(is,
sslOptions.getKeystorePassword().length == 0 ? null : sslOptions.getKeystorePassword()));
}
}

if (sslOptions.getTruststore() != null) {
try (InputStream is = sslOptions.getTruststore().openStream()) {
sslContextBuilder.trustManager(createTrustManagerFactory(is,
sslOptions.getTruststorePassword().length == 0 ? null : sslOptions.getTruststorePassword()));
}
}

SslContext sslContext = sslContextBuilder.build();

SSLEngine sslEngine = hostAndPort != null
? sslContext.newEngine(channel.alloc(), hostAndPort.getHostText(), hostAndPort.getPort())
: sslContext.newEngine(channel.alloc());
sslEngine.setSSLParameters(sslParams);
SSLEngine sslEngine = initializeSSLEngine(channel.alloc());

if (channel.pipeline().get("first") == null) {
channel.pipeline().addFirst("first", new ChannelDuplexHandler() {
Expand Down Expand Up @@ -273,6 +246,41 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
clientResources.nettyCustomizer().afterChannelInitialized(channel);
}

private SSLEngine initializeSSLEngine(ByteBufAllocator alloc) throws IOException, GeneralSecurityException {

SSLParameters sslParams = new SSLParameters();

SslContextBuilder sslContextBuilder = SslContextBuilder.forClient().sslProvider(sslOptions.getSslProvider());
if (verifyPeer) {
sslParams.setEndpointIdentificationAlgorithm("HTTPS");
} else {
sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
}

if (sslOptions.getKeystore() != null) {
try (InputStream is = sslOptions.getKeystore().openStream()) {
sslContextBuilder.keyManager(createKeyManagerFactory(is,
sslOptions.getKeystorePassword().length == 0 ? null : sslOptions.getKeystorePassword()));
}
}

if (sslOptions.getTruststore() != null) {
try (InputStream is = sslOptions.getTruststore().openStream()) {
sslContextBuilder.trustManager(createTrustManagerFactory(is,
sslOptions.getTruststorePassword().length == 0 ? null : sslOptions.getTruststorePassword()));
}
}

SslContext sslContext = sslContextBuilder.build();

SSLEngine sslEngine = hostAndPort != null
? sslContext.newEngine(alloc, hostAndPort.getHostText(), hostAndPort.getPort())
: sslContext.newEngine(alloc);
sslEngine.setSSLParameters(sslParams);

return sslEngine;
}

@Override
public CompletableFuture<Boolean> channelInitialized() {
return initializedFuture;
Expand Down
91 changes: 28 additions & 63 deletions src/test/java/io/lettuce/core/SslIntegrationTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
package io.lettuce.core;

import static io.lettuce.test.settings.TestSettings.host;
import static io.lettuce.test.settings.TestSettings.sslPort;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand All @@ -24,9 +23,7 @@
import java.io.File;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.GeneralSecurityException;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand All @@ -37,16 +34,17 @@
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.test.util.ReflectionTestUtils;

import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.sync.RedisCommands;
import io.lettuce.core.codec.StringCodec;
import io.lettuce.core.masterslave.MasterSlave;
import io.lettuce.core.pubsub.api.async.RedisPubSubAsyncCommands;
import io.lettuce.core.pubsub.api.sync.RedisPubSubCommands;
import io.lettuce.test.*;
import io.netty.handler.codec.DecoderException;
import io.lettuce.test.CanConnect;
import io.lettuce.test.Delay;
import io.lettuce.test.LettuceExtension;
import io.lettuce.test.Wait;
import io.lettuce.test.settings.TestSettings;
import io.netty.handler.ssl.OpenSsl;

/**
Expand All @@ -63,12 +61,12 @@ class SslIntegrationTests extends TestSupport {
private static final File TRUSTSTORE_FILE = new File(TRUSTSTORE);
private static final int MASTER_SLAVE_BASE_PORT_OFFSET = 2000;

private static final RedisURI URI_NO_VERIFY = sslURIBuilder(0) //
.withVerifyPeer(false) //
private static final RedisURI URI_VERIFY = sslURIBuilder(0) //
.withVerifyPeer(true) //
.build();

private static final RedisURI URI_VERIFY = sslURIBuilder(1) //
.withVerifyPeer(true) //
private static final RedisURI URI_NO_VERIFY = sslURIBuilder(1) //
.withVerifyPeer(false) //
.build();

private static final RedisURI URI_CLIENT_CERT_AUTH = sslURIBuilder(2) //
Expand Down Expand Up @@ -97,7 +95,7 @@ class SslIntegrationTests extends TestSupport {
@BeforeAll
static void beforeClass() {

assumeTrue(CanConnect.to(host(), sslPort()), "Assume that stunnel runs on port 6443");
assumeTrue(CanConnect.to(TestSettings.host(), sslPort()), "Assume that stunnel runs on port 6443");
assertThat(TRUSTSTORE_FILE).exists();
}

Expand Down Expand Up @@ -230,16 +228,16 @@ void regularSslWithReconnect() {
@Test
void sslWithVerificationWillFail() {

RedisURI redisUri = RedisURI.create("rediss://" + host() + ":" + sslPort());
RedisURI redisUri = RedisURI.create("rediss://" + TestSettings.host() + ":" + sslPort());

assertThatThrownBy(() -> redisClient.connect(redisUri).sync()).isInstanceOf(RedisConnectionException.class);
}

@Test
void masterSlaveWithSsl() {

RedisCommands<String, String> connection = MasterSlave.connect(redisClient, StringCodec.UTF8,
MASTER_SLAVE_URIS_NO_VERIFY).sync();
RedisCommands<String, String> connection = MasterSlave
.connect(redisClient, StringCodec.UTF8, MASTER_SLAVE_URIS_NO_VERIFY).sync();
connection.set("key", "value");
assertThat(connection.get("key")).isEqualTo("value");
connection.getStatefulConnection().close();
Expand Down Expand Up @@ -278,8 +276,8 @@ void masterSlaveWithJdkSslUsingTruststoreUrlWithWrongPassword() throws Exception
.build();
setOptions(sslOptions);

assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_VERIFY)).isInstanceOf(
RedisConnectionException.class);
assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_VERIFY))
.isInstanceOf(RedisConnectionException.class);
}

@Test
Expand All @@ -290,8 +288,8 @@ void masterSlaveWithJdkSslFailsWithWrongTruststore() {
.build();
setOptions(sslOptions);

assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_VERIFY)).isInstanceOf(
RedisConnectionException.class);
assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_VERIFY))
.isInstanceOf(RedisConnectionException.class);
}

@Test
Expand All @@ -304,8 +302,8 @@ void masterSlavePingBeforeActivate() {

@Test
void masterSlaveSslWithReconnect() {
RedisCommands<String, String> connection = MasterSlave.connect(redisClient, StringCodec.UTF8,
MASTER_SLAVE_URIS_NO_VERIFY).sync();
RedisCommands<String, String> connection = MasterSlave
.connect(redisClient, StringCodec.UTF8, MASTER_SLAVE_URIS_NO_VERIFY).sync();
connection.quit();
Delay.delay(Duration.ofMillis(200));
assertThat(connection.ping()).isEqualTo("PONG");
Expand All @@ -314,8 +312,8 @@ void masterSlaveSslWithReconnect() {

@Test
void masterSlaveSslWithVerificationWillFail() {
assertThatThrownBy(() -> MasterSlave.connect(redisClient, StringCodec.UTF8, MASTER_SLAVE_URIS_VERIFY)).isInstanceOf(
RedisConnectionException.class);
assertThatThrownBy(() -> MasterSlave.connect(redisClient, StringCodec.UTF8, MASTER_SLAVE_URIS_VERIFY))
.isInstanceOf(RedisConnectionException.class);
}

@Test
Expand All @@ -339,8 +337,8 @@ void masterSlaveSslWithAllInvalidHostsWillFail() {
.build();
setOptions(sslOptions);

assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_WITH_ALL_INVALID)).isInstanceOf(
RedisConnectionException.class);
assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_WITH_ALL_INVALID))
.isInstanceOf(RedisConnectionException.class);
}

@Test
Expand All @@ -365,50 +363,16 @@ void pubSubSsl() {
connection2.getStatefulConnection().close();
}

@Test
void pubSubSslAndBreakConnection() {

RedisURI redisURI = RedisURI.Builder.redis(host(), sslPort()).withSsl(true).withVerifyPeer(false).build();
redisClient.setOptions(ClientOptions.builder().suspendReconnectOnProtocolFailure(true).build());

RedisPubSubAsyncCommands<String, String> connection = redisClient.connectPubSub(redisURI).async();
RedisPubSubCommands<String, String> connection2 = redisClient.connectPubSub(redisURI).sync();

redisURI.setVerifyPeer(true);
connection.subscribe("c1");
connection.subscribe("c2");

Wait.untilTrue(() -> connection2.pubsubChannels().containsAll(Arrays.asList("c1", "c2"))).waitOrTimeout();

Futures.await(connection.quit());

List<String> future = connection2.pubsubChannels();
assertThat(future).doesNotContain("c1", "c2");

RedisChannelWriter channelWriter = ConnectionTestUtil.getChannelWriter(connection.getStatefulConnection());
Wait.untilNotEquals(null, () -> ReflectionTestUtils.getField(channelWriter, "connectionError")).waitOrTimeout();

RedisFuture<Void> defectFuture = connection.subscribe("foo");

assertThatThrownBy(() -> Futures.await(defectFuture)).hasCauseInstanceOf(DecoderException.class)
.hasRootCauseInstanceOf(GeneralSecurityException.class);

assertThat(defectFuture.toCompletableFuture()).isDone();

connection.getStatefulConnection().close();
connection2.getStatefulConnection().close();
}

private static RedisURI.Builder sslURIBuilder(int portOffset) {
return RedisURI.Builder.redis(host(), sslPort(portOffset)).withSsl(true);
return RedisURI.Builder.redis(TestSettings.host(), sslPort(portOffset)).withSsl(true);
}

private static List<RedisURI> sslUris(IntStream masterSlaveOffsets,
Function<RedisURI.Builder, RedisURI.Builder> builderCustomizer) {

return masterSlaveOffsets.map(it -> it + MASTER_SLAVE_BASE_PORT_OFFSET)
.mapToObj(offset -> RedisURI.Builder.redis(host(), sslPort(offset)).withSsl(true)).map(builderCustomizer)
.map(RedisURI.Builder::build).collect(Collectors.toList());
.mapToObj(offset -> RedisURI.Builder.redis(TestSettings.host(), sslPort(offset)).withSsl(true))
.map(builderCustomizer).map(RedisURI.Builder::build).collect(Collectors.toList());
}

private URL truststoreURL() throws MalformedURLException {
Expand All @@ -429,7 +393,8 @@ private void verifyConnection(RedisURI redisUri) {

private void verifyMasterSlaveConnection(List<RedisURI> redisUris) {

try (StatefulRedisConnection<String, String> connection = MasterSlave.connect(redisClient, StringCodec.UTF8, redisUris)) {
try (StatefulRedisConnection<String, String> connection = MasterSlave.connect(redisClient, StringCodec.UTF8,
redisUris)) {
connection.sync().ping();
}
}
Expand Down

0 comments on commit 120d57d

Please sign in to comment.