diff --git a/src/main/java/io/lettuce/core/SslConnectionBuilder.java b/src/main/java/io/lettuce/core/SslConnectionBuilder.java index 215544a3d9..1d3e80186e 100644 --- a/src/main/java/io/lettuce/core/SslConnectionBuilder.java +++ b/src/main/java/io/lettuce/core/SslConnectionBuilder.java @@ -42,6 +42,7 @@ import io.lettuce.core.internal.LettuceAssert; import io.lettuce.core.protocol.AsyncCommand; import io.lettuce.core.resource.ClientResources; +import io.netty.buffer.ByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandler; @@ -149,35 +150,7 @@ public SslChannelInitializer(Supplier> pingCommandSupplier @Override protected void initChannel(Channel channel) throws Exception { - SSLParameters sslParams = new SSLParameters(); - - SslContextBuilder sslContextBuilder = SslContextBuilder.forClient().sslProvider(sslOptions.getSslProvider()); - if (verifyPeer) { - sslParams.setEndpointIdentificationAlgorithm("HTTPS"); - } else { - sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); - } - - if (sslOptions.getKeystore() != null) { - try (InputStream is = sslOptions.getKeystore().openStream()) { - sslContextBuilder.keyManager(createKeyManagerFactory(is, - sslOptions.getKeystorePassword().length == 0 ? null : sslOptions.getKeystorePassword())); - } - } - - if (sslOptions.getTruststore() != null) { - try (InputStream is = sslOptions.getTruststore().openStream()) { - sslContextBuilder.trustManager(createTrustManagerFactory(is, - sslOptions.getTruststorePassword().length == 0 ? null : sslOptions.getTruststorePassword())); - } - } - - SslContext sslContext = sslContextBuilder.build(); - - SSLEngine sslEngine = hostAndPort != null - ? sslContext.newEngine(channel.alloc(), hostAndPort.getHostText(), hostAndPort.getPort()) - : sslContext.newEngine(channel.alloc()); - sslEngine.setSSLParameters(sslParams); + SSLEngine sslEngine = initializeSSLEngine(channel.alloc()); if (channel.pipeline().get("first") == null) { channel.pipeline().addFirst("first", new ChannelDuplexHandler() { @@ -273,6 +246,41 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E clientResources.nettyCustomizer().afterChannelInitialized(channel); } + private SSLEngine initializeSSLEngine(ByteBufAllocator alloc) throws IOException, GeneralSecurityException { + + SSLParameters sslParams = new SSLParameters(); + + SslContextBuilder sslContextBuilder = SslContextBuilder.forClient().sslProvider(sslOptions.getSslProvider()); + if (verifyPeer) { + sslParams.setEndpointIdentificationAlgorithm("HTTPS"); + } else { + sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); + } + + if (sslOptions.getKeystore() != null) { + try (InputStream is = sslOptions.getKeystore().openStream()) { + sslContextBuilder.keyManager(createKeyManagerFactory(is, + sslOptions.getKeystorePassword().length == 0 ? null : sslOptions.getKeystorePassword())); + } + } + + if (sslOptions.getTruststore() != null) { + try (InputStream is = sslOptions.getTruststore().openStream()) { + sslContextBuilder.trustManager(createTrustManagerFactory(is, + sslOptions.getTruststorePassword().length == 0 ? null : sslOptions.getTruststorePassword())); + } + } + + SslContext sslContext = sslContextBuilder.build(); + + SSLEngine sslEngine = hostAndPort != null + ? sslContext.newEngine(alloc, hostAndPort.getHostText(), hostAndPort.getPort()) + : sslContext.newEngine(alloc); + sslEngine.setSSLParameters(sslParams); + + return sslEngine; + } + @Override public CompletableFuture channelInitialized() { return initializedFuture; diff --git a/src/test/java/io/lettuce/core/SslIntegrationTests.java b/src/test/java/io/lettuce/core/SslIntegrationTests.java index 2bfa59a552..e13b221604 100644 --- a/src/test/java/io/lettuce/core/SslIntegrationTests.java +++ b/src/test/java/io/lettuce/core/SslIntegrationTests.java @@ -15,7 +15,6 @@ */ package io.lettuce.core; -import static io.lettuce.test.settings.TestSettings.host; import static io.lettuce.test.settings.TestSettings.sslPort; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -24,9 +23,7 @@ import java.io.File; import java.net.MalformedURLException; import java.net.URL; -import java.security.GeneralSecurityException; import java.time.Duration; -import java.util.Arrays; import java.util.List; import java.util.function.Function; import java.util.stream.Collectors; @@ -37,16 +34,17 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.springframework.test.util.ReflectionTestUtils; import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.sync.RedisCommands; import io.lettuce.core.codec.StringCodec; import io.lettuce.core.masterslave.MasterSlave; -import io.lettuce.core.pubsub.api.async.RedisPubSubAsyncCommands; import io.lettuce.core.pubsub.api.sync.RedisPubSubCommands; -import io.lettuce.test.*; -import io.netty.handler.codec.DecoderException; +import io.lettuce.test.CanConnect; +import io.lettuce.test.Delay; +import io.lettuce.test.LettuceExtension; +import io.lettuce.test.Wait; +import io.lettuce.test.settings.TestSettings; import io.netty.handler.ssl.OpenSsl; /** @@ -63,12 +61,12 @@ class SslIntegrationTests extends TestSupport { private static final File TRUSTSTORE_FILE = new File(TRUSTSTORE); private static final int MASTER_SLAVE_BASE_PORT_OFFSET = 2000; - private static final RedisURI URI_NO_VERIFY = sslURIBuilder(0) // - .withVerifyPeer(false) // + private static final RedisURI URI_VERIFY = sslURIBuilder(0) // + .withVerifyPeer(true) // .build(); - private static final RedisURI URI_VERIFY = sslURIBuilder(1) // - .withVerifyPeer(true) // + private static final RedisURI URI_NO_VERIFY = sslURIBuilder(1) // + .withVerifyPeer(false) // .build(); private static final RedisURI URI_CLIENT_CERT_AUTH = sslURIBuilder(2) // @@ -97,7 +95,7 @@ class SslIntegrationTests extends TestSupport { @BeforeAll static void beforeClass() { - assumeTrue(CanConnect.to(host(), sslPort()), "Assume that stunnel runs on port 6443"); + assumeTrue(CanConnect.to(TestSettings.host(), sslPort()), "Assume that stunnel runs on port 6443"); assertThat(TRUSTSTORE_FILE).exists(); } @@ -230,7 +228,7 @@ void regularSslWithReconnect() { @Test void sslWithVerificationWillFail() { - RedisURI redisUri = RedisURI.create("rediss://" + host() + ":" + sslPort()); + RedisURI redisUri = RedisURI.create("rediss://" + TestSettings.host() + ":" + sslPort()); assertThatThrownBy(() -> redisClient.connect(redisUri).sync()).isInstanceOf(RedisConnectionException.class); } @@ -238,8 +236,8 @@ void sslWithVerificationWillFail() { @Test void masterSlaveWithSsl() { - RedisCommands connection = MasterSlave.connect(redisClient, StringCodec.UTF8, - MASTER_SLAVE_URIS_NO_VERIFY).sync(); + RedisCommands connection = MasterSlave + .connect(redisClient, StringCodec.UTF8, MASTER_SLAVE_URIS_NO_VERIFY).sync(); connection.set("key", "value"); assertThat(connection.get("key")).isEqualTo("value"); connection.getStatefulConnection().close(); @@ -278,8 +276,8 @@ void masterSlaveWithJdkSslUsingTruststoreUrlWithWrongPassword() throws Exception .build(); setOptions(sslOptions); - assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_VERIFY)).isInstanceOf( - RedisConnectionException.class); + assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_VERIFY)) + .isInstanceOf(RedisConnectionException.class); } @Test @@ -290,8 +288,8 @@ void masterSlaveWithJdkSslFailsWithWrongTruststore() { .build(); setOptions(sslOptions); - assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_VERIFY)).isInstanceOf( - RedisConnectionException.class); + assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_VERIFY)) + .isInstanceOf(RedisConnectionException.class); } @Test @@ -304,8 +302,8 @@ void masterSlavePingBeforeActivate() { @Test void masterSlaveSslWithReconnect() { - RedisCommands connection = MasterSlave.connect(redisClient, StringCodec.UTF8, - MASTER_SLAVE_URIS_NO_VERIFY).sync(); + RedisCommands connection = MasterSlave + .connect(redisClient, StringCodec.UTF8, MASTER_SLAVE_URIS_NO_VERIFY).sync(); connection.quit(); Delay.delay(Duration.ofMillis(200)); assertThat(connection.ping()).isEqualTo("PONG"); @@ -314,8 +312,8 @@ void masterSlaveSslWithReconnect() { @Test void masterSlaveSslWithVerificationWillFail() { - assertThatThrownBy(() -> MasterSlave.connect(redisClient, StringCodec.UTF8, MASTER_SLAVE_URIS_VERIFY)).isInstanceOf( - RedisConnectionException.class); + assertThatThrownBy(() -> MasterSlave.connect(redisClient, StringCodec.UTF8, MASTER_SLAVE_URIS_VERIFY)) + .isInstanceOf(RedisConnectionException.class); } @Test @@ -339,8 +337,8 @@ void masterSlaveSslWithAllInvalidHostsWillFail() { .build(); setOptions(sslOptions); - assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_WITH_ALL_INVALID)).isInstanceOf( - RedisConnectionException.class); + assertThatThrownBy(() -> verifyMasterSlaveConnection(MASTER_SLAVE_URIS_WITH_ALL_INVALID)) + .isInstanceOf(RedisConnectionException.class); } @Test @@ -365,50 +363,16 @@ void pubSubSsl() { connection2.getStatefulConnection().close(); } - @Test - void pubSubSslAndBreakConnection() { - - RedisURI redisURI = RedisURI.Builder.redis(host(), sslPort()).withSsl(true).withVerifyPeer(false).build(); - redisClient.setOptions(ClientOptions.builder().suspendReconnectOnProtocolFailure(true).build()); - - RedisPubSubAsyncCommands connection = redisClient.connectPubSub(redisURI).async(); - RedisPubSubCommands connection2 = redisClient.connectPubSub(redisURI).sync(); - - redisURI.setVerifyPeer(true); - connection.subscribe("c1"); - connection.subscribe("c2"); - - Wait.untilTrue(() -> connection2.pubsubChannels().containsAll(Arrays.asList("c1", "c2"))).waitOrTimeout(); - - Futures.await(connection.quit()); - - List future = connection2.pubsubChannels(); - assertThat(future).doesNotContain("c1", "c2"); - - RedisChannelWriter channelWriter = ConnectionTestUtil.getChannelWriter(connection.getStatefulConnection()); - Wait.untilNotEquals(null, () -> ReflectionTestUtils.getField(channelWriter, "connectionError")).waitOrTimeout(); - - RedisFuture defectFuture = connection.subscribe("foo"); - - assertThatThrownBy(() -> Futures.await(defectFuture)).hasCauseInstanceOf(DecoderException.class) - .hasRootCauseInstanceOf(GeneralSecurityException.class); - - assertThat(defectFuture.toCompletableFuture()).isDone(); - - connection.getStatefulConnection().close(); - connection2.getStatefulConnection().close(); - } - private static RedisURI.Builder sslURIBuilder(int portOffset) { - return RedisURI.Builder.redis(host(), sslPort(portOffset)).withSsl(true); + return RedisURI.Builder.redis(TestSettings.host(), sslPort(portOffset)).withSsl(true); } private static List sslUris(IntStream masterSlaveOffsets, Function builderCustomizer) { return masterSlaveOffsets.map(it -> it + MASTER_SLAVE_BASE_PORT_OFFSET) - .mapToObj(offset -> RedisURI.Builder.redis(host(), sslPort(offset)).withSsl(true)).map(builderCustomizer) - .map(RedisURI.Builder::build).collect(Collectors.toList()); + .mapToObj(offset -> RedisURI.Builder.redis(TestSettings.host(), sslPort(offset)).withSsl(true)) + .map(builderCustomizer).map(RedisURI.Builder::build).collect(Collectors.toList()); } private URL truststoreURL() throws MalformedURLException { @@ -429,7 +393,8 @@ private void verifyConnection(RedisURI redisUri) { private void verifyMasterSlaveConnection(List redisUris) { - try (StatefulRedisConnection connection = MasterSlave.connect(redisClient, StringCodec.UTF8, redisUris)) { + try (StatefulRedisConnection connection = MasterSlave.connect(redisClient, StringCodec.UTF8, + redisUris)) { connection.sync().ping(); } }