diff --git a/src/main/java/com/lambdaworks/redis/cluster/PooledClusterConnectionProvider.java b/src/main/java/com/lambdaworks/redis/cluster/PooledClusterConnectionProvider.java index c2ee0c80b3..8e15a92a8f 100644 --- a/src/main/java/com/lambdaworks/redis/cluster/PooledClusterConnectionProvider.java +++ b/src/main/java/com/lambdaworks/redis/cluster/PooledClusterConnectionProvider.java @@ -94,7 +94,9 @@ public void close() { this.connections.invalidateAll(); resetPartitions(); for (RedisAsyncConnection kvRedisAsyncConnection : copy.values()) { - kvRedisAsyncConnection.close(); + if (kvRedisAsyncConnection.isOpen()) { + kvRedisAsyncConnection.close(); + } } } diff --git a/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java b/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java index a594541b96..f51cd4b173 100644 --- a/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java +++ b/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java @@ -3,24 +3,16 @@ package com.lambdaworks.redis.protocol; import java.nio.charset.Charset; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.List; import java.util.Queue; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.locks.ReentrantLock; -import com.lambdaworks.redis.ClientOptions; -import com.lambdaworks.redis.ConnectionEvents; -import com.lambdaworks.redis.RedisChannelHandler; -import com.lambdaworks.redis.RedisChannelWriter; -import com.lambdaworks.redis.RedisException; +import com.lambdaworks.redis.*; import io.netty.buffer.ByteBuf; -import io.netty.channel.Channel; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; +import io.netty.channel.*; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -38,13 +30,14 @@ public class CommandHandler extends ChannelDuplexHandler implements RedisC protected ClientOptions clientOptions; protected Queue> queue; - protected Queue> commandBuffer = new LinkedBlockingQueue>(); + protected Queue> commandBuffer = new ArrayDeque>(); protected ByteBuf buffer; protected RedisStateMachine rsm; + private LifecycleState lifecycleState = LifecycleState.NOT_CONNECTED; + private Object stateLock = new Object(); private Channel channel; - private boolean closed; - private boolean connected; + private RedisChannelHandler redisChannelHandler; private final ReentrantLock writeLock = new ReentrantLock(); private Throwable connectionError; @@ -79,19 +72,24 @@ public CommandHandler(ClientOptions clientOptions, Queue> */ @Override public void channelRegistered(ChannelHandlerContext ctx) throws Exception { - closed = false; + setState(LifecycleState.REGISTERED); buffer = ctx.alloc().heapBuffer(); rsm = new RedisStateMachine(); - channel = ctx.channel(); + synchronized (stateLock) { + channel = ctx.channel(); + } } @Override public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { releaseBuffer(); - if (closed) { + + if (lifecycleState == LifecycleState.CLOSED) { cancelCommands("Connection closed"); } - channel = null; + synchronized (stateLock) { + channel = null; + } } /** @@ -133,7 +131,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup @Override public RedisCommand write(RedisCommand command) { - if (closed) { + if (lifecycleState == LifecycleState.CLOSED) { throw new RedisException("Connection is closed"); } try { @@ -141,9 +139,10 @@ public RedisCommand write(RedisCommand command) { * This lock causes safety for connection activation and somehow netty gets more stable and predictable performance * than without a lock and all threads are hammering towards writeAndFlush. */ + writeLock.lock(); Channel channel = this.channel; - if (channel != null && connected && channel.isActive()) { + if (channel != null && isConnected() && channel.isActive()) { if (debugEnabled) { logger.debug("{} write() writeAndFlush Command {}", logPrefix(), command); } @@ -159,10 +158,14 @@ public RedisCommand write(RedisCommand command) { */ if (!channel.isActive()) { - write(command); + return write(command); } } else { + if (commandBuffer.contains(command) || queue.contains(command)) { + return command; + } + if (connectionError != null) { if (debugEnabled) { logger.debug("{} write() completing Command {} due to connection error", logPrefix(), command); @@ -187,6 +190,11 @@ public RedisCommand write(RedisCommand command) { return command; } + private boolean isConnected() { + return lifecycleState.ordinal() >= LifecycleState.CONNECTED.ordinal() + && lifecycleState.ordinal() <= LifecycleState.DISCONNECTED.ordinal(); + } + /** * * @see io.netty.channel.ChannelDuplexHandler#write(io.netty.channel.ChannelHandlerContext, java.lang.Object, @@ -214,8 +222,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { if (debugEnabled) { logger.debug("{} channelActive()", logPrefix()); } - connected = true; - closed = false; + setStateIfNotClosed(LifecycleState.CONNECTED); try { executeQueuedCommands(ctx); @@ -251,14 +258,19 @@ protected void executeQueuedCommands(ChannelHandlerContext ctx) { if (debugEnabled) { logger.debug("{} executeQueuedCommands {} command(s) queued", logPrefix(), queue.size()); } - channel = ctx.channel(); + + synchronized (stateLock) { + channel = ctx.channel(); + } if (redisChannelHandler != null) { if (debugEnabled) { logger.debug("{} activating channel handler", logPrefix()); } + setStateIfNotClosed(LifecycleState.ACTIVATING); redisChannelHandler.activated(); } + setStateIfNotClosed(LifecycleState.ACTIVE); for (RedisCommand cmd : tmp) { if (!cmd.isCancelled()) { @@ -289,14 +301,21 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { if (debugEnabled) { logger.debug("{} channelInactive()", logPrefix()); } - connected = false; + setStateIfNotClosed(LifecycleState.DISCONNECTED); if (redisChannelHandler != null) { if (debugEnabled) { logger.debug("{} deactivating channel handler", logPrefix()); } + setStateIfNotClosed(LifecycleState.DEACTIVATING); redisChannelHandler.deactivated(); } + setStateIfNotClosed(LifecycleState.DEACTIVATED); + + if (buffer != null) { + rsm.reset(); + buffer.clear(); + } if (debugEnabled) { logger.debug("{} channelInactive() done", logPrefix()); @@ -304,6 +323,18 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { super.channelInactive(ctx); } + protected void setStateIfNotClosed(LifecycleState lifecycleState) { + if (this.lifecycleState != LifecycleState.CLOSED) { + setState(lifecycleState); + } + } + + protected void setState(LifecycleState lifecycleState) { + synchronized (stateLock) { + this.lifecycleState = lifecycleState; + } + } + private void cancelCommands(String message) { int size = 0; if (queue != null) { @@ -342,7 +373,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E command.complete(); } - if (channel == null || !connected) { + if (channel == null || !channel.isActive() || !isConnected()) { connectionError = cause; return; } @@ -359,11 +390,11 @@ public void close() { logger.debug("{} close()", logPrefix()); } - if (closed) { + if (lifecycleState == LifecycleState.CLOSED) { return; } - closed = true; + setStateIfNotClosed(LifecycleState.CLOSED); Channel currentChannel = this.channel; if (currentChannel != null) { currentChannel.pipeline().fireUserEventTriggered(new ConnectionEvents.PrepareClose()); @@ -380,7 +411,7 @@ private void releaseBuffer() { } public boolean isClosed() { - return closed; + return lifecycleState == LifecycleState.CLOSED; } /** @@ -425,4 +456,9 @@ private String logPrefix() { return logPrefix = buffer.toString(); } + enum LifecycleState { + + NOT_CONNECTED, REGISTERED, CONNECTED, ACTIVATING, ACTIVE, DISCONNECTED, DEACTIVATING, DEACTIVATED, CLOSED, + } + } diff --git a/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java b/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java index df915ebc88..9ef79a8a7c 100644 --- a/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java +++ b/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java @@ -2,14 +2,15 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Matchers.any; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; +import java.util.ArrayDeque; +import java.util.Queue; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import com.lambdaworks.redis.ClientOptions; +import com.lambdaworks.redis.RedisException; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -24,10 +25,13 @@ @RunWith(MockitoJUnitRunner.class) public class CommandHandlerTest { - private BlockingQueue> q = new ArrayBlockingQueue>(10); + private Queue> q = new ArrayDeque>(10); private CommandHandler sut = new CommandHandler(new ClientOptions.Builder().build(), q); + private Command command = new Command(CommandType.APPEND, + new StatusOutput(new Utf8StringCodec()), null); + @Mock private ChannelHandlerContext context; @@ -37,7 +41,10 @@ public class CommandHandlerTest { @Test public void testExceptionChannelActive() throws Exception { + sut.setState(CommandHandler.LifecycleState.ACTIVE); + when(context.channel()).thenReturn(channel); + when(channel.isActive()).thenReturn(true); sut.channelActive(context); sut.exceptionCaught(context, new Exception()); @@ -46,19 +53,20 @@ public void testExceptionChannelActive() throws Exception { @Test public void testExceptionChannelInactive() throws Exception { + sut.setState(CommandHandler.LifecycleState.DISCONNECTED); sut.exceptionCaught(context, new Exception()); verify(context, never()).fireExceptionCaught(any(Exception.class)); } @Test public void testExceptionWithQueue() throws Exception { + sut.setState(CommandHandler.LifecycleState.ACTIVE); q.clear(); when(context.channel()).thenReturn(channel); sut.channelActive(context); + when(channel.isActive()).thenReturn(true); - Command command = new Command(CommandType.APPEND, - new StatusOutput(new Utf8StringCodec()), null); q.add(command); sut.exceptionCaught(context, new Exception()); @@ -68,4 +76,21 @@ public void testExceptionWithQueue() throws Exception { verify(context).fireExceptionCaught(any(Exception.class)); } + @Test(expected = RedisException.class) + public void testWriteWhenClosed() throws Exception { + + sut.setState(CommandHandler.LifecycleState.CLOSED); + + sut.write(command); + } + + @Test + public void testExceptionWhenClosed() throws Exception { + + sut.setState(CommandHandler.LifecycleState.CLOSED); + + sut.exceptionCaught(context, new Exception()); + verifyZeroInteractions(context); + } + }