diff --git a/src/main/java/io/lettuce/core/cluster/ClusterDistributionChannelWriter.java b/src/main/java/io/lettuce/core/cluster/ClusterDistributionChannelWriter.java index 5d10fc7982..5c386cd3c4 100644 --- a/src/main/java/io/lettuce/core/cluster/ClusterDistributionChannelWriter.java +++ b/src/main/java/io/lettuce/core/cluster/ClusterDistributionChannelWriter.java @@ -46,6 +46,7 @@ import io.lettuce.core.output.StatusOutput; import io.lettuce.core.protocol.Command; import io.lettuce.core.protocol.CommandArgs; +import io.lettuce.core.protocol.CommandExpiryWriter; import io.lettuce.core.protocol.CommandKeyword; import io.lettuce.core.protocol.CommandType; import io.lettuce.core.protocol.ConnectionFacade; @@ -76,7 +77,7 @@ class ClusterDistributionChannelWriter implements RedisChannelWriter { private volatile Partitions partitions; - ClusterDistributionChannelWriter(ClientOptions clientOptions, RedisChannelWriter defaultWriter, + ClusterDistributionChannelWriter(RedisChannelWriter defaultWriter, ClientOptions clientOptions, ClusterEventListener clusterEventListener) { if (clientOptions instanceof ClusterClientOptions) { @@ -426,16 +427,29 @@ public CompletableFuture closeAsync() { } public void disconnectDefaultEndpoint() { + unwrapDefaultEndpoint().disconnect(); + } - DefaultEndpoint defaultEndpoint; + private DefaultEndpoint unwrapDefaultEndpoint() { - if (defaultWriter instanceof CommandListenerWriter) { - defaultEndpoint = (DefaultEndpoint) ((CommandListenerWriter) defaultWriter).getDelegate(); - } else { - defaultEndpoint = ((DefaultEndpoint) defaultWriter); + RedisChannelWriter writer = this.defaultWriter; + + while (!(writer instanceof DefaultEndpoint)) { + + if (writer instanceof CommandListenerWriter) { + writer = ((CommandListenerWriter) writer).getDelegate(); + continue; + } + + if (writer instanceof CommandExpiryWriter) { + writer = ((CommandExpiryWriter) writer).getDelegate(); + continue; + } + + throw new IllegalStateException(String.format("Cannot unwrap defaultWriter %s into DefaultEndpoint", writer)); } - defaultEndpoint.disconnect(); + return (DefaultEndpoint) writer; } @Override diff --git a/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java b/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java index 0c7ad31895..4099ebb80d 100644 --- a/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java +++ b/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java @@ -651,7 +651,7 @@ private CompletableFuture> connectCl writer = new CommandListenerWriter(writer, getCommandListeners()); } - ClusterDistributionChannelWriter clusterWriter = new ClusterDistributionChannelWriter(getClusterClientOptions(), writer, + ClusterDistributionChannelWriter clusterWriter = new ClusterDistributionChannelWriter(writer, getClusterClientOptions(), topologyRefreshScheduler); PooledClusterConnectionProvider pooledClusterConnectionProvider = new PooledClusterConnectionProvider<>(this, clusterWriter, codec, topologyRefreshScheduler); @@ -748,7 +748,7 @@ private CompletableFuture> con writer = new CommandListenerWriter(writer, getCommandListeners()); } - ClusterDistributionChannelWriter clusterWriter = new ClusterDistributionChannelWriter(getClusterClientOptions(), writer, + ClusterDistributionChannelWriter clusterWriter = new ClusterDistributionChannelWriter(writer, getClusterClientOptions(), topologyRefreshScheduler); ClusterPubSubConnectionProvider pooledClusterConnectionProvider = new ClusterPubSubConnectionProvider<>(this, diff --git a/src/main/java/io/lettuce/core/protocol/CommandExpiryWriter.java b/src/main/java/io/lettuce/core/protocol/CommandExpiryWriter.java index 49608e2c24..63bb652abc 100644 --- a/src/main/java/io/lettuce/core/protocol/CommandExpiryWriter.java +++ b/src/main/java/io/lettuce/core/protocol/CommandExpiryWriter.java @@ -15,7 +15,7 @@ */ package io.lettuce.core.protocol; -import static io.lettuce.core.TimeoutOptions.TimeoutSource; +import static io.lettuce.core.TimeoutOptions.*; import java.time.Duration; import java.util.Collection; @@ -25,9 +25,9 @@ import java.util.concurrent.TimeUnit; import io.lettuce.core.ClientOptions; -import io.lettuce.core.internal.ExceptionFactory; import io.lettuce.core.RedisChannelWriter; import io.lettuce.core.TimeoutOptions; +import io.lettuce.core.internal.ExceptionFactory; import io.lettuce.core.internal.LettuceAssert; import io.lettuce.core.resource.ClientResources; @@ -41,7 +41,7 @@ */ public class CommandExpiryWriter implements RedisChannelWriter { - private final RedisChannelWriter writer; + private final RedisChannelWriter delegate; private final TimeoutSource source; @@ -56,18 +56,18 @@ public class CommandExpiryWriter implements RedisChannelWriter { /** * Create a new {@link CommandExpiryWriter}. * - * @param writer must not be {@code null}. + * @param delegate must not be {@code null}. * @param clientOptions must not be {@code null}. * @param clientResources must not be {@code null}. */ - public CommandExpiryWriter(RedisChannelWriter writer, ClientOptions clientOptions, ClientResources clientResources) { + public CommandExpiryWriter(RedisChannelWriter delegate, ClientOptions clientOptions, ClientResources clientResources) { - LettuceAssert.notNull(writer, "RedisChannelWriter must not be null"); + LettuceAssert.notNull(delegate, "RedisChannelWriter must not be null"); LettuceAssert.isTrue(isSupported(clientOptions), "Command timeout not enabled"); LettuceAssert.notNull(clientResources, "ClientResources must not be null"); TimeoutOptions timeoutOptions = clientOptions.getTimeoutOptions(); - this.writer = writer; + this.delegate = delegate; this.source = timeoutOptions.getSource(); this.applyConnectionTimeout = timeoutOptions.isApplyConnectionTimeout(); this.timeUnit = source.getTimeUnit(); @@ -96,24 +96,24 @@ private static boolean isSupported(TimeoutOptions timeoutOptions) { @Override public void setConnectionFacade(ConnectionFacade connectionFacade) { - writer.setConnectionFacade(connectionFacade); + delegate.setConnectionFacade(connectionFacade); } @Override public ClientResources getClientResources() { - return writer.getClientResources(); + return delegate.getClientResources(); } @Override public void setAutoFlushCommands(boolean autoFlush) { - writer.setAutoFlushCommands(autoFlush); + delegate.setAutoFlushCommands(autoFlush); } @Override public RedisCommand write(RedisCommand command) { potentiallyExpire(command, getExecutorService()); - return writer.write(command); + return delegate.write(command); } @Override @@ -125,33 +125,37 @@ public RedisCommand write(RedisCommand command) { potentiallyExpire(command, executorService); } - return writer.write(redisCommands); + return delegate.write(redisCommands); } @Override public void flushCommands() { - writer.flushCommands(); + delegate.flushCommands(); } @Override public void close() { - writer.close(); + delegate.close(); } @Override public CompletableFuture closeAsync() { - return writer.closeAsync(); + return delegate.closeAsync(); } @Override public void reset() { - writer.reset(); + delegate.reset(); } public void setTimeout(Duration timeout) { this.timeout = timeUnit.convert(timeout.toNanos(), TimeUnit.NANOSECONDS); } + public RedisChannelWriter getDelegate() { + return delegate; + } + private ScheduledExecutorService getExecutorService() { return this.executorService; } diff --git a/src/test/java/io/lettuce/core/cluster/ClusterDistributionChannelWriterUnitTests.java b/src/test/java/io/lettuce/core/cluster/ClusterDistributionChannelWriterUnitTests.java index 118c20b495..c14c8cc5b8 100644 --- a/src/test/java/io/lettuce/core/cluster/ClusterDistributionChannelWriterUnitTests.java +++ b/src/test/java/io/lettuce/core/cluster/ClusterDistributionChannelWriterUnitTests.java @@ -33,8 +33,10 @@ import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; -import io.lettuce.core.RedisChannelWriter; +import io.lettuce.core.ClientOptions; +import io.lettuce.core.CommandListenerWriter; import io.lettuce.core.StatefulRedisConnectionImpl; +import io.lettuce.core.TimeoutOptions; import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.cluster.ClusterConnectionProvider.Intent; import io.lettuce.core.codec.StringCodec; @@ -44,7 +46,9 @@ import io.lettuce.core.protocol.AsyncCommand; import io.lettuce.core.protocol.Command; import io.lettuce.core.protocol.CommandArgs; +import io.lettuce.core.protocol.CommandExpiryWriter; import io.lettuce.core.protocol.CommandType; +import io.lettuce.core.protocol.DefaultEndpoint; import io.lettuce.core.protocol.RedisCommand; import io.lettuce.core.resource.ClientResources; @@ -59,7 +63,7 @@ class ClusterDistributionChannelWriterUnitTests { @Mock - private RedisChannelWriter defaultWriter; + private DefaultEndpoint defaultWriter; @Mock private EventBus eventBus; @@ -171,6 +175,21 @@ void shouldWriteCommandListWhenAsking() { verifyWriteCommandCountWhenRedirecting(false); } + @Test + void shouldDisconnectWrappedEndpoint() { + + CommandListenerWriter listenerWriter = new CommandListenerWriter(defaultWriter, Collections.emptyList()); + CommandExpiryWriter expiryWriter = new CommandExpiryWriter(listenerWriter, + ClientOptions.builder().timeoutOptions(TimeoutOptions.enabled()).build(), clientResources); + + ClusterDistributionChannelWriter writer = new ClusterDistributionChannelWriter(expiryWriter, ClientOptions.create(), + clusterEventListener); + + writer.disconnectDefaultEndpoint(); + + verify(defaultWriter).disconnect(); + } + @Test void shouldWriteOneCommandWhenMoved() { verifyWriteCommandCountWhenRedirecting(true);