diff --git a/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java b/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java index 4b4fcbccba..e8b5623a7c 100644 --- a/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java +++ b/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java @@ -647,8 +647,25 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) return; } + if (msg instanceof List) { + + List> batch = (List>) msg; + + if (batch.size() == 1) { + + writeSingleCommand(ctx, batch.get(0), promise); + return; + } + + writeBatch(ctx, batch, promise); + return; + } + if (msg instanceof Collection) { - writeBatch(ctx, (Collection>) msg, promise); + + Collection> batch = (Collection>) msg; + + writeBatch(ctx, batch, promise); } } @@ -666,25 +683,22 @@ private void writeSingleCommand(ChannelHandlerContext ctx, RedisCommand private void writeBatch(ChannelHandlerContext ctx, Collection> batch, ChannelPromise promise) throws Exception { - Collection> toWrite = batch; - int commandsToWrite = 0; + Collection> deduplicated = new LinkedHashSet<>(batch.size(), 1); - boolean cancelledCommands = false; - for (RedisCommand command : batch) { + for (RedisCommand command : batch) { - if (!isWriteable(command)) { - cancelledCommands = true; - break; + if (isWriteable(command) && !deduplicated.add(command)) { + deduplicated.remove(command); + command.completeExceptionally(new RedisException( + "Attempting to write duplicate command that is already enqueued: " + command)); } - - commandsToWrite++; } try { - validateWrite(commandsToWrite); + validateWrite(deduplicated.size()); } catch (Exception e) { - for (RedisCommand redisCommand : toWrite) { + for (RedisCommand redisCommand : deduplicated) { redisCommand.completeExceptionally(e); } @@ -692,26 +706,12 @@ private void writeBatch(ChannelHandlerContext ctx, Collection(batch.size()); - - for (RedisCommand command : batch) { - - if (!isWriteable(command)) { - continue; - } - - toWrite.add(command); - } - } - - for (RedisCommand command : toWrite) { + for (RedisCommand command : deduplicated) { addToStack(command, promise); } - if (!toWrite.isEmpty()) { - ctx.write(toWrite, promise); + if (!deduplicated.isEmpty()) { + ctx.write(deduplicated, promise); } } @@ -732,10 +732,6 @@ private void addToStack(RedisCommand command, ChannelPromise promise) { RedisCommand commandToUse = potentiallyWrapLatencyCommand(command); - if (stack.contains(command)) { - throw new RedisException("Attempting to write duplicate command that is already enqueued: " + command); - } - if (promise.getClass() == VOID_PROMISE_CLASS) { stack.add(commandToUse); } else { diff --git a/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java b/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java index a73e7cb979..2f4acc295b 100644 --- a/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java +++ b/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java @@ -498,6 +498,18 @@ public void shouldCancelCommandOnQueueBatchFailure() throws Exception { verify(commandMock).completeExceptionally(exception); } + @Test + public void shouldFailOnDuplicateCommands() throws Exception { + + Command commandMock = mock(Command.class); + + ChannelPromise channelPromise = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + sut.write(context, Arrays.asList(commandMock, commandMock), channelPromise); + + assertThat(stack).isEmpty(); + verify(commandMock).completeExceptionally(any(RedisException.class)); + } + @Test public void shouldWriteActiveCommands() throws Exception { @@ -520,17 +532,29 @@ public void shouldNotWriteCancelledCommandBatch() throws Exception { } @Test - public void shouldWriteActiveCommandsInBatch() throws Exception { + public void shouldWriteSingleActiveCommandsInBatch() throws Exception { when(promise.isSuccess()).thenReturn(true); List> commands = Arrays.asList(command); sut.write(context, commands, promise); - verify(context).write(commands, promise); + verify(context).write(command, promise); assertThat(stack).hasSize(1); } + @Test + public void shouldWriteActiveCommandsInBatch() throws Exception { + + Command anotherCommand = new Command<>(CommandType.APPEND, + new StatusOutput<>(StringCodec.UTF8), null); + + List> commands = Arrays.asList(command, anotherCommand); + sut.write(context, commands, promise); + + verify(context).write(any(Set.class), eq(promise)); + } + @Test @SuppressWarnings("unchecked") public void shouldWriteActiveCommandsInMixedBatch() throws Exception { @@ -543,7 +567,7 @@ public void shouldWriteActiveCommandsInMixedBatch() throws Exception { sut.write(context, Arrays.asList(command, command2), promise); - ArgumentCaptor captor = ArgumentCaptor.forClass(List.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(Collection.class); verify(context).write(captor.capture(), any()); assertThat(captor.getValue()).containsOnly(command2);