Skip to content

Commit

Permalink
Fix CommandExpiryWriter unwrapping in ClusterDistributionChannelWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
mp911de committed Dec 3, 2021
1 parent 8d032fe commit 8b96b88
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import io.lettuce.core.output.StatusOutput;
import io.lettuce.core.protocol.Command;
import io.lettuce.core.protocol.CommandArgs;
import io.lettuce.core.protocol.CommandExpiryWriter;
import io.lettuce.core.protocol.CommandKeyword;
import io.lettuce.core.protocol.CommandType;
import io.lettuce.core.protocol.ConnectionFacade;
Expand Down Expand Up @@ -76,7 +77,7 @@ class ClusterDistributionChannelWriter implements RedisChannelWriter {

private volatile Partitions partitions;

ClusterDistributionChannelWriter(ClientOptions clientOptions, RedisChannelWriter defaultWriter,
ClusterDistributionChannelWriter(RedisChannelWriter defaultWriter, ClientOptions clientOptions,
ClusterEventListener clusterEventListener) {

if (clientOptions instanceof ClusterClientOptions) {
Expand Down Expand Up @@ -426,16 +427,29 @@ public CompletableFuture<Void> closeAsync() {
}

public void disconnectDefaultEndpoint() {
unwrapDefaultEndpoint().disconnect();
}

DefaultEndpoint defaultEndpoint;
private DefaultEndpoint unwrapDefaultEndpoint() {

if (defaultWriter instanceof CommandListenerWriter) {
defaultEndpoint = (DefaultEndpoint) ((CommandListenerWriter) defaultWriter).getDelegate();
} else {
defaultEndpoint = ((DefaultEndpoint) defaultWriter);
RedisChannelWriter writer = this.defaultWriter;

while (!(writer instanceof DefaultEndpoint)) {

if (writer instanceof CommandListenerWriter) {
writer = ((CommandListenerWriter) writer).getDelegate();
continue;
}

if (writer instanceof CommandExpiryWriter) {
writer = ((CommandExpiryWriter) writer).getDelegate();
continue;
}

throw new IllegalStateException(String.format("Cannot unwrap defaultWriter %s into DefaultEndpoint", writer));
}

defaultEndpoint.disconnect();
return (DefaultEndpoint) writer;
}

@Override
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/io/lettuce/core/cluster/RedisClusterClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ private <K, V> CompletableFuture<StatefulRedisClusterConnection<K, V>> connectCl
writer = new CommandListenerWriter(writer, getCommandListeners());
}

ClusterDistributionChannelWriter clusterWriter = new ClusterDistributionChannelWriter(getClusterClientOptions(), writer,
ClusterDistributionChannelWriter clusterWriter = new ClusterDistributionChannelWriter(writer, getClusterClientOptions(),
topologyRefreshScheduler);
PooledClusterConnectionProvider<K, V> pooledClusterConnectionProvider = new PooledClusterConnectionProvider<>(this,
clusterWriter, codec, topologyRefreshScheduler);
Expand Down Expand Up @@ -748,7 +748,7 @@ private <K, V> CompletableFuture<StatefulRedisClusterPubSubConnection<K, V>> con
writer = new CommandListenerWriter(writer, getCommandListeners());
}

ClusterDistributionChannelWriter clusterWriter = new ClusterDistributionChannelWriter(getClusterClientOptions(), writer,
ClusterDistributionChannelWriter clusterWriter = new ClusterDistributionChannelWriter(writer, getClusterClientOptions(),
topologyRefreshScheduler);

ClusterPubSubConnectionProvider<K, V> pooledClusterConnectionProvider = new ClusterPubSubConnectionProvider<>(this,
Expand Down
36 changes: 20 additions & 16 deletions src/main/java/io/lettuce/core/protocol/CommandExpiryWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package io.lettuce.core.protocol;

import static io.lettuce.core.TimeoutOptions.TimeoutSource;
import static io.lettuce.core.TimeoutOptions.*;

import java.time.Duration;
import java.util.Collection;
Expand All @@ -25,9 +25,9 @@
import java.util.concurrent.TimeUnit;

import io.lettuce.core.ClientOptions;
import io.lettuce.core.internal.ExceptionFactory;
import io.lettuce.core.RedisChannelWriter;
import io.lettuce.core.TimeoutOptions;
import io.lettuce.core.internal.ExceptionFactory;
import io.lettuce.core.internal.LettuceAssert;
import io.lettuce.core.resource.ClientResources;

Expand All @@ -41,7 +41,7 @@
*/
public class CommandExpiryWriter implements RedisChannelWriter {

private final RedisChannelWriter writer;
private final RedisChannelWriter delegate;

private final TimeoutSource source;

Expand All @@ -56,18 +56,18 @@ public class CommandExpiryWriter implements RedisChannelWriter {
/**
* Create a new {@link CommandExpiryWriter}.
*
* @param writer must not be {@code null}.
* @param delegate must not be {@code null}.
* @param clientOptions must not be {@code null}.
* @param clientResources must not be {@code null}.
*/
public CommandExpiryWriter(RedisChannelWriter writer, ClientOptions clientOptions, ClientResources clientResources) {
public CommandExpiryWriter(RedisChannelWriter delegate, ClientOptions clientOptions, ClientResources clientResources) {

LettuceAssert.notNull(writer, "RedisChannelWriter must not be null");
LettuceAssert.notNull(delegate, "RedisChannelWriter must not be null");
LettuceAssert.isTrue(isSupported(clientOptions), "Command timeout not enabled");
LettuceAssert.notNull(clientResources, "ClientResources must not be null");

TimeoutOptions timeoutOptions = clientOptions.getTimeoutOptions();
this.writer = writer;
this.delegate = delegate;
this.source = timeoutOptions.getSource();
this.applyConnectionTimeout = timeoutOptions.isApplyConnectionTimeout();
this.timeUnit = source.getTimeUnit();
Expand Down Expand Up @@ -96,24 +96,24 @@ private static boolean isSupported(TimeoutOptions timeoutOptions) {

@Override
public void setConnectionFacade(ConnectionFacade connectionFacade) {
writer.setConnectionFacade(connectionFacade);
delegate.setConnectionFacade(connectionFacade);
}

@Override
public ClientResources getClientResources() {
return writer.getClientResources();
return delegate.getClientResources();
}

@Override
public void setAutoFlushCommands(boolean autoFlush) {
writer.setAutoFlushCommands(autoFlush);
delegate.setAutoFlushCommands(autoFlush);
}

@Override
public <K, V, T> RedisCommand<K, V, T> write(RedisCommand<K, V, T> command) {

potentiallyExpire(command, getExecutorService());
return writer.write(command);
return delegate.write(command);
}

@Override
Expand All @@ -125,33 +125,37 @@ public <K, V, T> RedisCommand<K, V, T> write(RedisCommand<K, V, T> command) {
potentiallyExpire(command, executorService);
}

return writer.write(redisCommands);
return delegate.write(redisCommands);
}

@Override
public void flushCommands() {
writer.flushCommands();
delegate.flushCommands();
}

@Override
public void close() {
writer.close();
delegate.close();
}

@Override
public CompletableFuture<Void> closeAsync() {
return writer.closeAsync();
return delegate.closeAsync();
}

@Override
public void reset() {
writer.reset();
delegate.reset();
}

public void setTimeout(Duration timeout) {
this.timeout = timeUnit.convert(timeout.toNanos(), TimeUnit.NANOSECONDS);
}

public RedisChannelWriter getDelegate() {
return delegate;
}

private ScheduledExecutorService getExecutorService() {
return this.executorService;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;

import io.lettuce.core.RedisChannelWriter;
import io.lettuce.core.ClientOptions;
import io.lettuce.core.CommandListenerWriter;
import io.lettuce.core.StatefulRedisConnectionImpl;
import io.lettuce.core.TimeoutOptions;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.cluster.ClusterConnectionProvider.Intent;
import io.lettuce.core.codec.StringCodec;
Expand All @@ -44,7 +46,9 @@
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.Command;
import io.lettuce.core.protocol.CommandArgs;
import io.lettuce.core.protocol.CommandExpiryWriter;
import io.lettuce.core.protocol.CommandType;
import io.lettuce.core.protocol.DefaultEndpoint;
import io.lettuce.core.protocol.RedisCommand;
import io.lettuce.core.resource.ClientResources;

Expand All @@ -59,7 +63,7 @@
class ClusterDistributionChannelWriterUnitTests {

@Mock
private RedisChannelWriter defaultWriter;
private DefaultEndpoint defaultWriter;

@Mock
private EventBus eventBus;
Expand Down Expand Up @@ -171,6 +175,21 @@ void shouldWriteCommandListWhenAsking() {
verifyWriteCommandCountWhenRedirecting(false);
}

@Test
void shouldDisconnectWrappedEndpoint() {

CommandListenerWriter listenerWriter = new CommandListenerWriter(defaultWriter, Collections.emptyList());
CommandExpiryWriter expiryWriter = new CommandExpiryWriter(listenerWriter,
ClientOptions.builder().timeoutOptions(TimeoutOptions.enabled()).build(), clientResources);

ClusterDistributionChannelWriter writer = new ClusterDistributionChannelWriter(expiryWriter, ClientOptions.create(),
clusterEventListener);

writer.disconnectDefaultEndpoint();

verify(defaultWriter).disconnect();
}

@Test
void shouldWriteOneCommandWhenMoved() {
verifyWriteCommandCountWhenRedirecting(true);
Expand Down

0 comments on commit 8b96b88

Please sign in to comment.