diff --git a/src/main/java/com/lambdaworks/redis/protocol/AsyncCommand.java b/src/main/java/com/lambdaworks/redis/protocol/AsyncCommand.java index bd0306bccc..be551af3eb 100644 --- a/src/main/java/com/lambdaworks/redis/protocol/AsyncCommand.java +++ b/src/main/java/com/lambdaworks/redis/protocol/AsyncCommand.java @@ -206,15 +206,8 @@ public boolean equals(Object o) { return false; } - RedisCommand left = command; - while (left instanceof DecoratedCommand) { - left = CommandWrapper.unwrap(left); - } - - RedisCommand right = (RedisCommand) o; - while (right instanceof DecoratedCommand) { - right = CommandWrapper.unwrap(right); - } + RedisCommand left = CommandWrapper.unwrap(command); + RedisCommand right = CommandWrapper.unwrap((RedisCommand) o); return left == right; } @@ -222,10 +215,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - RedisCommand toHash = command; - while (toHash instanceof DecoratedCommand) { - toHash = CommandWrapper.unwrap(toHash); - } + RedisCommand toHash = CommandWrapper.unwrap(command); return toHash != null ? toHash.hashCode() : 0; } diff --git a/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java b/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java index 924b929f41..c3a15505de 100644 --- a/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java +++ b/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java @@ -712,12 +712,20 @@ private void addToStack(RedisCommand command, ChannelPromise promise) { + ". Commands are not accepted until the stack size drops."); } - command = potentiallyWrapLatencyCommand(command); + RedisCommand commandToUse = potentiallyWrapLatencyCommand(command); if (stack.contains(command)) { throw new RedisException("Attempting to write duplicate command that is already enqueued: " + command); + } + + if (promise.isVoid()) { + stack.add(commandToUse); } else { - stack.add(command); + promise.addListener(future -> { + if (future.isSuccess()) { + stack.add(commandToUse); + } + }); } } } catch (RuntimeException e) { @@ -1169,23 +1177,14 @@ private class ListenerSupport { this.sentCommands = sentCommands; } - void dequeue(boolean success) { + void dequeue() { if (sentCommand != null) { - QUEUE_SIZE.decrementAndGet(CommandHandler.this); - if (!success) { - CommandHandler.this.stack.remove(sentCommand); - CommandHandler.this.disconnectedBuffer.remove(sentCommand); - } } if (sentCommands != null) { QUEUE_SIZE.addAndGet(CommandHandler.this, -sentCommands.size()); - if (!success) { - CommandHandler.this.stack.removeAll(sentCommands); - CommandHandler.this.disconnectedBuffer.removeAll(sentCommands); - } } } @@ -1217,7 +1216,7 @@ private class AtMostOnceWriteListener extends ListenerSupport implements Channel @Override public void operationComplete(ChannelFuture future) throws Exception { - dequeue(true); + dequeue(); if (future.cause() != null) { complete(future.cause()); @@ -1245,7 +1244,7 @@ public void operationComplete(Future future) throws Exception { Throwable cause = future.cause(); boolean success = future.isSuccess(); - dequeue(success); + dequeue(); if (!success) { Channel channel = CommandHandler.this.channel; @@ -1256,7 +1255,7 @@ public void operationComplete(Future future) throws Exception { } } - if (!future.isSuccess() && !(cause instanceof ClosedChannelException)) { + if (!success && !(cause instanceof ClosedChannelException)) { String message = "Unexpected exception during request: {}"; InternalLogLevel logLevel = InternalLogLevel.WARN; diff --git a/src/test/java/com/lambdaworks/redis/cluster/ClusterPartiallyDownTest.java b/src/test/java/com/lambdaworks/redis/cluster/ClusterPartiallyDownTest.java index aece8071a1..4df0ddec72 100644 --- a/src/test/java/com/lambdaworks/redis/cluster/ClusterPartiallyDownTest.java +++ b/src/test/java/com/lambdaworks/redis/cluster/ClusterPartiallyDownTest.java @@ -18,6 +18,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Fail.fail; +import java.io.IOException; import java.net.ConnectException; import java.util.ArrayList; import java.util.HashSet; @@ -96,7 +97,7 @@ public void operateOnPartiallyDownCluster() throws Exception { connection.sync().get(key_10439); fail("Missing RedisException"); } catch (RedisException e) { - assertThat(e).hasRootCauseInstanceOf(ConnectException.class); + assertThat(e).hasRootCauseInstanceOf(IOException.class); } connection.close(); @@ -133,7 +134,7 @@ public void partitionNodesAreOffline() throws Exception { redisClusterClient.connect(); fail("Missing RedisConnectionException"); } catch (RedisConnectionException e) { - assertThat(e).hasRootCauseInstanceOf(ConnectException.class); + assertThat(e).hasRootCauseInstanceOf(IOException.class); } } } diff --git a/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java b/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java index 8f65454709..40bacf8402 100644 --- a/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java +++ b/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java @@ -56,6 +56,8 @@ import edu.umd.cs.mtc.TestFramework; import io.netty.buffer.ByteBuf; import io.netty.channel.*; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.ImmediateEventExecutor; @RunWith(MockitoJUnitRunner.class) @@ -87,6 +89,9 @@ public class CommandHandlerTest { @Mock private RedisChannelHandler channelHandler; + @Mock + private ChannelPromise promise; + @BeforeClass public static void beforeClass() { LoggerContext ctx = (LoggerContext) LogManager.getContext(); @@ -129,13 +134,21 @@ public void before() throws Exception { stack.addAll((Collection) invocation.getArguments()[0]); } - return new DefaultChannelPromise(channel); + return promise; }); sut = new CommandHandler<>(ClientOptions.create(), clientResources); sut.setRedisChannelHandler(channelHandler); disconnectedBuffer = (Queue) ReflectionTestUtils.getField(sut, "disconnectedBuffer"); stack = (Queue) ReflectionTestUtils.getField(sut, "stack"); + + when(promise.addListener(any())).then(invocation -> { + + GenericFutureListener> listener = invocation.getArgument(0); + listener.operationComplete(promise); + + return null; + }); } @Test @@ -145,7 +158,6 @@ public void testChannelActive() throws Exception { sut.channelActive(context); verify(pipeline).fireUserEventTriggered(any(ConnectionEvents.Activated.class)); - } @Test @@ -173,6 +185,8 @@ public void testChannelActiveFailureShouldCancelCommands() throws Exception { @Test public void testChannelActiveWithBufferedAndQueuedCommands() throws Exception { + when(promise.isSuccess()).thenReturn(true); + Command bufferedCommand = new Command<>(CommandType.GET, new StatusOutput<>( new Utf8StringCodec()), null); @@ -261,6 +275,8 @@ public void testChannelActiveReplayBufferedCommands() throws Exception { disconnectedBuffer.add(bufferedCommand1); disconnectedBuffer.add(bufferedCommand2); + when(promise.isSuccess()).thenReturn(true); + sut.channelRegistered(context); sut.channelActive(context); @@ -470,12 +486,25 @@ public void shouldCancelCommandOnQueueBatchFailure() throws Exception { verify(commandMock).completeExceptionally(exception); } + @Test + public void shouldWriteActiveCommandsForVoidPromise() throws Exception { + + when(promise.isVoid()).thenReturn(true); + + sut.write(context, command, promise); + + verify(context).write(command, promise); + assertThat(stack).hasSize(1).allMatch(o -> o instanceof LatencyMeteredCommand); + } + @Test public void shouldWriteActiveCommands() throws Exception { - sut.write(context, command, null); + when(promise.isSuccess()).thenReturn(true); - verify(context).write(command, null); + sut.write(context, command, promise); + + verify(context).write(command, promise); assertThat(stack).hasSize(1).allMatch(o -> o instanceof LatencyMeteredCommand); } @@ -483,7 +512,7 @@ public void shouldWriteActiveCommands() throws Exception { public void shouldNotWriteCancelledCommandBatch() throws Exception { command.cancel(); - sut.write(context, Arrays.asList(command), null); + sut.write(context, Arrays.asList(command), promise); verifyZeroInteractions(context); assertThat(disconnectedBuffer).isEmpty(); @@ -492,10 +521,12 @@ public void shouldNotWriteCancelledCommandBatch() throws Exception { @Test public void shouldWriteActiveCommandsInBatch() throws Exception { + when(promise.isSuccess()).thenReturn(true); + List> commands = Arrays.asList(command); - sut.write(context, commands, null); + sut.write(context, commands, promise); - verify(context).write(commands, null); + verify(context).write(commands, promise); assertThat(stack).hasSize(1); } @@ -503,12 +534,13 @@ public void shouldWriteActiveCommandsInBatch() throws Exception { @SuppressWarnings("unchecked") public void shouldWriteActiveCommandsInMixedBatch() throws Exception { + when(promise.isSuccess()).thenReturn(true); + Command command2 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), null); - command.cancel(); - sut.write(context, Arrays.asList(command, command2), null); + sut.write(context, Arrays.asList(command, command2), promise); ArgumentCaptor captor = ArgumentCaptor.forClass(List.class); verify(context).write(captor.capture(), any()); @@ -532,10 +564,11 @@ public void shouldIgnoreNonReadableBuffers() throws Exception { @Test(timeout = 5000) public void shouldRebuildHugeQueue() throws Exception { + when(promise.isSuccess()).thenReturn(true); + for (int i = 0; i < 500000; i++) { Command command = new Command<>(CommandType.SET, new StatusOutput<>(StringCodec.UTF8)); - disconnectedBuffer.add(new AsyncCommand<>(command)); }