diff --git a/src/main/java/com/lambdaworks/redis/protocol/CommandEncoder.java b/src/main/java/com/lambdaworks/redis/protocol/CommandEncoder.java index 60f43d14f3..29f007bd65 100644 --- a/src/main/java/com/lambdaworks/redis/protocol/CommandEncoder.java +++ b/src/main/java/com/lambdaworks/redis/protocol/CommandEncoder.java @@ -3,6 +3,7 @@ package com.lambdaworks.redis.protocol; import java.nio.charset.Charset; +import java.util.Collection; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; @@ -18,7 +19,7 @@ * @author Mark Paluch */ @ChannelHandler.Sharable -public class CommandEncoder extends MessageToByteEncoder> { +public class CommandEncoder extends MessageToByteEncoder { private static final InternalLogger logger = InternalLoggerFactory.getInstance(CommandEncoder.class); @@ -43,12 +44,28 @@ public CommandEncoder(boolean preferDirect) { } @Override - protected void encode(ChannelHandlerContext ctx, RedisCommand msg, ByteBuf out) throws Exception { + protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) throws Exception { - msg.encode(out); + if (msg instanceof RedisCommand) { + RedisCommand command = (RedisCommand) msg; + encode(ctx, out, command); + } + + if (msg instanceof Collection) { + Collection> commands = (Collection>) msg; + for (RedisCommand command : commands) { + if (command.isCancelled()) { + continue; + } + encode(ctx, out, command); + } + } + } + private void encode(ChannelHandlerContext ctx, ByteBuf out, RedisCommand command) { + command.encode(out); if (debugEnabled) { - logger.debug("{} writing command {}", logPrefix(ctx.channel()), msg); + logger.debug("{} writing command {}", logPrefix(ctx.channel()), command); if (traceEnabled) { logger.trace("{} Sent: {}", logPrefix(ctx.channel()), out.toString(Charset.defaultCharset()).trim()); } @@ -60,5 +77,4 @@ private String logPrefix(Channel channel) { buffer.append('[').append(ChannelLogDescriptor.logDescriptor(channel)).append(']'); return buffer.toString(); } - } diff --git a/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java b/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java index 61533fcc5f..76f0335ceb 100644 --- a/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java +++ b/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java @@ -6,10 +6,12 @@ import java.nio.charset.Charset; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Queue; import java.util.concurrent.locks.ReentrantLock; +import com.google.common.collect.ImmutableList; import com.lambdaworks.redis.ClientOptions; import com.lambdaworks.redis.ConnectionEvents; import com.lambdaworks.redis.RedisChannelHandler; @@ -44,7 +46,10 @@ public class CommandHandler extends ChannelDuplexHandler implements RedisC protected ClientOptions clientOptions; protected Queue> queue; - protected Queue> commandBuffer = new ArrayDeque>(); + + // all access to the commandBuffer is synchronized + protected Queue> commandBuffer = newCommandBuffer(); + protected ByteBuf buffer; protected RedisStateMachine rsm; protected Channel channel; @@ -177,43 +182,42 @@ public > C write(C command) { writeLock.lock(); Channel channel = this.channel; - if (channel != null && isConnected() && channel.isActive()) { - if (debugEnabled) { - logger.debug("{} write() writeAndFlush Command {}", logPrefix(), command); - } + if (autoFlushCommands) { - if (reliability == Reliability.AT_MOST_ONCE) { - // cancel on exceptions and remove from queue, because there is no housekeeping - channel.write(command).addListener(new AtMostOnceWriteListener(command, queue)); - } + if (channel != null && isConnected() && channel.isActive()) { + if (debugEnabled) { + logger.debug("{} write() writeAndFlush Command {}", logPrefix(), command); + } - if (reliability == Reliability.AT_LEAST_ONCE) { - // commands are ok to stay within the queue, reconnect will retrigger them - channel.write(command).addListener(WRITE_LOG_LISTENER); - } + if (reliability == Reliability.AT_MOST_ONCE) { + // cancel on exceptions and remove from queue, because there is no housekeeping + channel.writeAndFlush(command).addListener(new AtMostOnceWriteListener(command, queue)); + } - if (autoFlushCommands) { - channel.flush(); - } + if (reliability == Reliability.AT_LEAST_ONCE) { + // commands are ok to stay within the queue, reconnect will retrigger them + channel.writeAndFlush(command).addListener(WRITE_LOG_LISTENER); + } + } else { - } else { + if (commandBuffer.contains(command) || queue.contains(command)) { + return command; + } - if (commandBuffer.contains(command) || queue.contains(command)) { - return command; - } + if (connectionError != null) { + if (debugEnabled) { + logger.debug("{} write() completing Command {} due to connection error", logPrefix(), command); + } + command.completeExceptionally(connectionError); - if (connectionError != null) { - if (debugEnabled) { - logger.debug("{} write() completing Command {} due to connection error", logPrefix(), command); + return command; } - command.completeExceptionally(connectionError); + bufferCommand(command); return command; } - if (debugEnabled) { - logger.debug("{} write() buffering Command {}", logPrefix(), command); - } - commandBuffer.add(command); + } else { + bufferCommand(command); } } finally { writeLock.unlock(); @@ -225,6 +229,13 @@ public > C write(C command) { return command; } + private void bufferCommand(RedisCommand command) { + if (debugEnabled) { + logger.debug("{} write() buffering Command {}", logPrefix(), command); + } + commandBuffer.add(command); + } + private boolean isConnected() { synchronized (lifecycleState) { return lifecycleState.ordinal() >= LifecycleState.CONNECTED.ordinal() @@ -233,9 +244,27 @@ private boolean isConnected() { } @Override + @SuppressWarnings("rawtypes") public void flushCommands() { - if (channel != null && isConnected() && channel.isActive()) { - channel.flush(); + if (channel != null && isConnected()) { + Queue> queuedCommands; + try { + writeLock.lock(); + queuedCommands = (Queue) commandBuffer; + commandBuffer = newCommandBuffer(); + } finally { + writeLock.unlock(); + } + + if (reliability == Reliability.AT_MOST_ONCE) { + // cancel on exceptions and remove from queue, because there is no housekeeping + channel.writeAndFlush(queuedCommands).addListener(new AtMostOnceWriteListener(queuedCommands, this.queue)); + } + + if (reliability == Reliability.AT_LEAST_ONCE) { + // commands are ok to stay within the queue, reconnect will retrigger them + channel.writeAndFlush(queuedCommands).addListener(WRITE_LOG_LISTENER); + } } } @@ -248,7 +277,24 @@ public void flushCommands() { @SuppressWarnings("unchecked") public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - final RedisCommand cmd = (RedisCommand) msg; + if (msg instanceof Collection) { + Collection> commands = (Collection>) msg; + for (RedisCommand command : commands) { + queueCommand(promise, command); + } + ctx.write(commands, promise); + return; + } + + RedisCommand cmd = (RedisCommand) msg; + queueCommand(promise, cmd); + ctx.write(cmd, promise); + } + + private void queueCommand(ChannelPromise promise, RedisCommand cmd) throws Exception { + if (cmd.isCancelled()) { + return; + } try { if (cmd.getOutput() == null) { @@ -261,8 +307,6 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) promise.setFailure(e); throw e; } - - ctx.write(cmd, promise); } @Override @@ -299,7 +343,7 @@ public void run() { } protected void executeQueuedCommands(ChannelHandlerContext ctx) { - List> tmp = new ArrayList<>(queue.size() + commandBuffer.size()); + Queue> tmp = newCommandBuffer(); try { writeLock.lock(); @@ -309,7 +353,7 @@ protected void executeQueuedCommands(ChannelHandlerContext ctx) { tmp.addAll(queue); queue.clear(); - commandBuffer.clear(); + commandBuffer = tmp; if (debugEnabled) { logger.debug("{} executeQueuedCommands {} command(s) queued", logPrefix(), tmp.size()); @@ -328,23 +372,10 @@ protected void executeQueuedCommands(ChannelHandlerContext ctx) { } setStateIfNotClosed(LifecycleState.ACTIVE); - for (RedisCommand cmd : tmp) { - if (!cmd.isCancelled()) { - - if (debugEnabled) { - logger.debug("{} channelActive() triggering command {}", logPrefix(), cmd); - } - - write(cmd); - } - } - - tmp.clear(); - + flushCommands(); } finally { writeLock.unlock(); } - } /** @@ -525,6 +556,10 @@ protected String logPrefix() { return logPrefix = buffer.toString(); } + private ArrayDeque> newCommandBuffer() { + return new ArrayDeque>(512); + } + public enum LifecycleState { NOT_CONNECTED, REGISTERED, CONNECTED, ACTIVATING, ACTIVE, DISCONNECTED, DEACTIVATING, DEACTIVATED, CLOSED, } @@ -535,11 +570,17 @@ private enum Reliability { private static class AtMostOnceWriteListener implements ChannelFutureListener { - private final RedisCommand sentCommand; + private final Collection> sentCommands; private final Queue queue; + @SuppressWarnings("rawtypes") public AtMostOnceWriteListener(RedisCommand sentCommand, Queue queue) { - this.sentCommand = sentCommand; + this.sentCommands = (Collection) ImmutableList.of(sentCommand); + this.queue = queue; + } + + public AtMostOnceWriteListener(Collection> sentCommand, Queue queue) { + this.sentCommands = sentCommand; this.queue = queue; } @@ -547,8 +588,10 @@ public AtMostOnceWriteListener(RedisCommand sentCommand, Queue queue public void operationComplete(ChannelFuture future) throws Exception { future.await(); if (future.cause() != null) { - sentCommand.completeExceptionally(future.cause()); - queue.remove(sentCommand); + for (RedisCommand sentCommand : sentCommands) { + sentCommand.completeExceptionally(future.cause()); + } + queue.removeAll(sentCommands); } } } @@ -558,7 +601,6 @@ public void operationComplete(ChannelFuture future) throws Exception { * */ static class WriteLogListener implements GenericFutureListener> { - @Override public void operationComplete(Future future) throws Exception { if (!future.isSuccess() && !(future.cause() instanceof ClosedChannelException)) diff --git a/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java b/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java index 66c46a5ce9..e0dc8e3afe 100644 --- a/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java +++ b/src/test/java/com/lambdaworks/redis/protocol/CommandHandlerTest.java @@ -12,6 +12,10 @@ import java.util.Queue; import java.util.concurrent.Future; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -34,11 +38,11 @@ @RunWith(MockitoJUnitRunner.class) public class CommandHandlerTest { - private Queue> q = new ArrayDeque>(10); + private Queue> q = new ArrayDeque<>(10); - private CommandHandler sut = new CommandHandler(new ClientOptions.Builder().build(), q); + private CommandHandler sut = new CommandHandler<>(new ClientOptions.Builder().build(), q); - private Command command = new Command(CommandType.APPEND, + private Command command = new Command<>(CommandType.APPEND, new StatusOutput(new Utf8StringCodec()), null); @Mock @@ -47,6 +51,9 @@ public class CommandHandlerTest { @Mock private Channel channel; + @Mock + private ByteBufAllocator byteBufAllocator; + @Mock private ChannelPipeline pipeline; @@ -56,6 +63,7 @@ public class CommandHandlerTest { @Before public void before() throws Exception { when(context.channel()).thenReturn(channel); + when(context.alloc()).thenReturn(byteBufAllocator); when(channel.pipeline()).thenReturn(pipeline); when(channel.eventLoop()).thenReturn(eventLoop); when(eventLoop.submit(any(Runnable.class))).thenAnswer(new Answer() { @@ -66,11 +74,15 @@ public Future answer(InvocationOnMock invocation) throws Throwable { return null; } }); + + when(channel.write(any())).thenAnswer(invocation -> new DefaultChannelPromise(channel)); + + when(channel.writeAndFlush(any())).thenAnswer(invocation -> new DefaultChannelPromise(channel)); } @Test public void testChannelActive() throws Exception { - sut.setState(CommandHandler.LifecycleState.REGISTERED); + sut.channelRegistered(context); sut.channelActive(context);