From ecbe36555a2aa32513aa3e6e886ceadea0b50b00 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 19 Mar 2018 15:20:12 +0100 Subject: [PATCH] Inspect Pub/Sub responses for interleaved messages #724 Lettuce now inspects Redis responses via PubSubCommandHandler and ReplayOutput whether a received response is a Pub/Sub message or whether the response belongs to a command on the protocol stack. Introspection is required as Redis responses may contain interleaved messages that do not belong to a command or may arrive before the command response. Previously, interleaved messages could get used to complete commands on the protocol stack which causes a defunct protocol state. --- .../redis/output/ReplayOutput.java | 197 ++++++++++++++++++ .../redis/protocol/CommandHandler.java | 69 ++++-- .../redis/pubsub/PubSubCommandHandler.java | 151 +++++++++++++- .../lambdaworks/redis/ProtectedModeTests.java | 15 ++ .../redis/output/ReplayOutputTest.java | 82 ++++++++ .../pubsub/PubSubCommandHandlerTest.java | 77 +++++++ 6 files changed, 573 insertions(+), 18 deletions(-) create mode 100644 src/main/java/com/lambdaworks/redis/output/ReplayOutput.java create mode 100644 src/test/java/com/lambdaworks/redis/output/ReplayOutputTest.java diff --git a/src/main/java/com/lambdaworks/redis/output/ReplayOutput.java b/src/main/java/com/lambdaworks/redis/output/ReplayOutput.java new file mode 100644 index 0000000000..72524eb172 --- /dev/null +++ b/src/main/java/com/lambdaworks/redis/output/ReplayOutput.java @@ -0,0 +1,197 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lambdaworks.redis.output; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import com.lambdaworks.redis.codec.RedisCodec; +import com.lambdaworks.redis.codec.StringCodec; + +/** + * Replayable {@link CommandOutput} capturing output signals to replay these on a target {@link CommandOutput}. Replay is useful + * when the response requires inspection prior to dispatching the actual output to a command target. + * + * @author Mark Paluch + * @since 4.4.4 + */ +public class ReplayOutput extends CommandOutput> { + + /** + * Initialize a new instance that encodes and decodes keys and values using the supplied codec. + */ + public ReplayOutput() { + super((RedisCodec) StringCodec.ASCII, new ArrayList<>()); + } + + @Override + public void set(ByteBuffer bytes) { + output.add(new BulkString(bytes)); + } + + @Override + public void set(long integer) { + output.add(new Integer(integer)); + } + + @Override + public void setError(ByteBuffer error) { + error.mark(); + output.add(new ErrorBytes(error)); + error.reset(); + super.setError(error); + } + + @Override + public void setError(String error) { + output.add(new ErrorString(error)); + super.setError(error); + } + + @Override + public void complete(int depth) { + output.add(new Complete(depth)); + } + + @Override + public void multi(int count) { + output.add(new Multi(count)); + } + + /** + * Replay all captured signals on a {@link CommandOutput}. + * + * @param target the target {@link CommandOutput}. + */ + public void replay(CommandOutput target) { + + for (Signal signal : output) { + signal.replay(target); + } + } + + /** + * Encapsulates a replayable decoding signal. + */ + public static abstract class Signal { + + /** + * Replay the signal on a {@link CommandOutput}. + * + * @param target + */ + protected abstract void replay(CommandOutput target); + } + + abstract static class BulkStringSupport extends Signal { + + final ByteBuffer message; + + BulkStringSupport(ByteBuffer message) { + + if (message != null) { + + // need to copy the buffer to prevent buffer lifecycle mismatch + this.message = ByteBuffer.allocate(message.remaining()); + this.message.put(message); + this.message.rewind(); + } else { + this.message = null; + } + } + } + + public static class BulkString extends BulkStringSupport { + + BulkString(ByteBuffer message) { + super(message); + } + + @Override + protected void replay(CommandOutput target) { + target.set(message); + } + } + + static class Integer extends Signal { + + final long message; + + Integer(long message) { + this.message = message; + } + + @Override + protected void replay(CommandOutput target) { + target.set(message); + } + } + + public static class ErrorBytes extends BulkStringSupport { + + ErrorBytes(ByteBuffer message) { + super(message); + } + + @Override + protected void replay(CommandOutput target) { + target.setError(message); + } + } + + static class ErrorString extends Signal { + + final String message; + + ErrorString(String message) { + this.message = message; + } + + @Override + protected void replay(CommandOutput target) { + target.setError(message); + } + } + + static class Multi extends Signal { + + final int count; + + Multi(int count) { + this.count = count; + } + + @Override + protected void replay(CommandOutput target) { + target.multi(count); + } + } + + static class Complete extends Signal { + + final int depth; + + public Complete(int depth) { + this.depth = depth; + } + + @Override + protected void replay(CommandOutput target) { + target.complete(depth); + } + } +} diff --git a/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java b/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java index 767f5a6fbb..166dbbeae7 100644 --- a/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java +++ b/src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java @@ -29,6 +29,7 @@ import com.lambdaworks.redis.internal.LettuceClassUtils; import com.lambdaworks.redis.internal.LettuceFactories; import com.lambdaworks.redis.internal.LettuceSets; +import com.lambdaworks.redis.output.CommandOutput; import com.lambdaworks.redis.resource.ClientResources; import io.netty.buffer.ByteBuf; @@ -145,6 +146,10 @@ public CommandHandler(ClientOptions clientOptions, ClientResources clientResourc boundedQueue = clientOptions.getRequestQueueSize() != Integer.MAX_VALUE; } + protected Deque> getStack() { + return stack; + } + @Override public void setRedisChannelHandler(RedisChannelHandler redisChannelHandler) { this.redisChannelHandler = redisChannelHandler; @@ -301,19 +306,21 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) { } if (isProtectedMode(command)) { - onProtectedMode(command.getOutput().getError()); + onProtectedMode(getCommandOutput(command).getError()); } else { - stack.poll(); + if (canComplete(command)) { + stack.poll(); - try { - command.complete(); - } catch (Exception e) { - logger.warn("{} Unexpected exception during command completion: {}", logPrefix, e.toString(), e); + try { + complete(command); + } catch (Exception e) { + logger.warn("{} Unexpected exception during command completion: {}", logPrefix, e.toString(), e); + } } } - afterComplete(ctx, command); + afterDecode(ctx, command); } if (buffer.refCnt() != 0) { @@ -383,19 +390,57 @@ private void onProtectedMode(String message) { cancelCommands(message); } + /** + * Decoding hook: Can the buffer be decoded to a command. + * + * @param buffer + * @return + */ + protected boolean canDecode(ByteBuf buffer) { + return !stack.isEmpty() && buffer.isReadable(); + } + + /** + * Decoding hook: Can the command be completed. + * + * @param command + * @return + */ + protected boolean canComplete(RedisCommand command) { + return true; + } + + /** + * Decoding hook: Complete a command. + * + * @param command + * @see RedisCommand#complete() + */ + protected void complete(RedisCommand command) { + command.complete(); + } + /** * Hook method called after command completion. * * @param ctx * @param command */ - protected void afterComplete(ChannelHandlerContext ctx, RedisCommand command) { + protected void afterDecode(ChannelHandlerContext ctx, RedisCommand command) { } - protected boolean canDecode(ByteBuf buffer) { - return !stack.isEmpty() && buffer.isReadable(); + /** + * Decoding hook: Retrieve {@link CommandOutput} for {@link RedisCommand} decoding. + * + * @param command + * @return + * @see RedisCommand#getOutput() + */ + protected CommandOutput getCommandOutput(RedisCommand command) { + return command.getOutput(); } + private boolean decode(ByteBuf buffer, RedisCommand command) { if (latencyMetricsEnabled && command instanceof WithLatency) { @@ -405,7 +450,7 @@ private boolean decode(ByteBuf buffer, RedisCommand command) { withLatency.firstResponse(nanoTime()); } - if (!rsm.decode(buffer, command, command.getOutput())) { + if (!rsm.decode(buffer, command, getCommandOutput(command))) { return false; } @@ -414,7 +459,7 @@ private boolean decode(ByteBuf buffer, RedisCommand command) { return true; } - return rsm.decode(buffer, command, command.getOutput()); + return rsm.decode(buffer, command, getCommandOutput(command)); } private void recordLatency(WithLatency withLatency, ProtocolKeyword commandType) { diff --git a/src/main/java/com/lambdaworks/redis/pubsub/PubSubCommandHandler.java b/src/main/java/com/lambdaworks/redis/pubsub/PubSubCommandHandler.java index 6d20861b37..f5a22f1568 100644 --- a/src/main/java/com/lambdaworks/redis/pubsub/PubSubCommandHandler.java +++ b/src/main/java/com/lambdaworks/redis/pubsub/PubSubCommandHandler.java @@ -15,8 +15,15 @@ */ package com.lambdaworks.redis.pubsub; +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Deque; + import com.lambdaworks.redis.ClientOptions; import com.lambdaworks.redis.codec.RedisCodec; +import com.lambdaworks.redis.codec.StringCodec; +import com.lambdaworks.redis.output.CommandOutput; +import com.lambdaworks.redis.output.ReplayOutput; import com.lambdaworks.redis.protocol.CommandHandler; import com.lambdaworks.redis.protocol.RedisCommand; import com.lambdaworks.redis.resource.ClientResources; @@ -26,8 +33,11 @@ import io.netty.channel.ChannelHandlerContext; /** - * A netty {@link ChannelHandler} responsible for writing redis pub/sub commands and reading the response stream from the - * server. + * A netty {@link ChannelHandler} responsible for writing Redis Pub/Sub commands and reading the response stream from the + * server. {@link PubSubCommandHandler} accounts for Pub/Sub message notification calling back + * {@link ChannelHandlerContext#fireChannelRead(Object)}. Redis responses can be interleaved in the sense that a response + * contains a Pub/Sub message first, then a command response. Possible interleave is introspected via + * {@link ResponseHeaderReplayOutput} and decoding hooks. * * @param Key type. * @param Value type. @@ -37,6 +47,9 @@ public class PubSubCommandHandler extends CommandHandler { private final RedisCodec codec; + private final Deque> queue = new ArrayDeque<>(); + + private ResponseHeaderReplayOutput replay; private PubSubOutput output; /** @@ -54,12 +67,32 @@ public PubSubCommandHandler(ClientOptions clientOptions, ClientResources clientR this.output = new PubSubOutput<>(codec); } + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + + replay = null; + queue.clear(); + + super.channelInactive(ctx); + } + + @SuppressWarnings("unchecked") @Override protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) { - super.decode(ctx, buffer); + if (!getStack().isEmpty()) { + super.decode(ctx, buffer); + } - while (buffer.isReadable()) { + ReplayOutput replay; + while ((replay = queue.poll()) != null) { + + replay.replay(output); + ctx.fireChannelRead(output); + output = new PubSubOutput<>(codec); + } + + while (super.getStack().isEmpty() && buffer.isReadable()) { if (!rsm.decode(buffer, output)) { return; @@ -67,8 +100,9 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) { ctx.fireChannelRead(output); output = new PubSubOutput<>(codec); - buffer.discardReadBytes(); } + + buffer.discardReadBytes(); } @Override @@ -77,9 +111,114 @@ protected boolean canDecode(ByteBuf buffer) { } @Override - protected void afterComplete(ChannelHandlerContext ctx, RedisCommand command) { + protected boolean canComplete(RedisCommand command) { + + if (isPubSubMessage(replay)) { + + queue.add(replay); + replay = null; + return false; + } + + return super.canComplete(command); + } + + @Override + protected void complete(RedisCommand command) { + + if (replay != null && command.getOutput() != null) { + try { + replay.replay(command.getOutput()); + } catch (Exception e) { + command.completeExceptionally(e); + } + replay = null; + } + + super.complete(command); + } + + /** + * Check whether {@link ResponseHeaderReplayOutput} contains a Pub/Sub message that requires Pub/Sub dispatch instead of to + * be used as Command output. + * + * @param replay + * @return + */ + private static boolean isPubSubMessage(ResponseHeaderReplayOutput replay) { + + if (replay == null) { + return false; + } + + String firstElement = replay.firstElement; + if (replay.multiCount != null && firstElement != null) { + + if (replay.multiCount == 3 && firstElement.equalsIgnoreCase(PubSubOutput.Type.message.name())) { + return true; + } + + if (replay.multiCount == 4 && firstElement.equalsIgnoreCase(PubSubOutput.Type.pmessage.name())) { + return true; + } + } + + return false; + } + + @Override + protected CommandOutput getCommandOutput(RedisCommand command) { + + if (getStack().isEmpty() || command.getOutput() == null) { + return super.getCommandOutput(command); + } + + if (replay == null) { + replay = new ResponseHeaderReplayOutput<>(); + } + + return replay; + } + + @Override + protected void afterDecode(ChannelHandlerContext ctx, RedisCommand command) { if (command.getOutput() instanceof PubSubOutput) { ctx.fireChannelRead(command.getOutput()); } } + + /** + * Inspectable {@link ReplayOutput} to investigate the first multi and string response elements. + * + * @param + * @param + */ + static class ResponseHeaderReplayOutput extends ReplayOutput { + + Integer multiCount; + String firstElement; + + @Override + public void set(ByteBuffer bytes) { + + if (firstElement == null && bytes != null && bytes.remaining() > 0) { + + bytes.mark(); + firstElement = StringCodec.ASCII.decodeKey(bytes); + bytes.reset(); + } + + super.set(bytes); + } + + @Override + public void multi(int count) { + + if (multiCount == null) { + multiCount = count; + } + + super.multi(count); + } + } } diff --git a/src/test/java/com/lambdaworks/redis/ProtectedModeTests.java b/src/test/java/com/lambdaworks/redis/ProtectedModeTests.java index 5eba6d9161..f8a8df389b 100644 --- a/src/test/java/com/lambdaworks/redis/ProtectedModeTests.java +++ b/src/test/java/com/lambdaworks/redis/ProtectedModeTests.java @@ -98,6 +98,21 @@ public void regularClientFailsOnFirstCommand() { } } + @Test + public void pubSubClientFailsOnFirstCommand() { + + try (StatefulRedisConnection connect = client.connectPubSub()) { + + connect.sync().ping(); + } catch (RedisException e) { + if (e.getCause() instanceof IOException) { + assertThat(e).hasCauseInstanceOf(IOException.class); + } else { + assertThat(e.getCause()).hasMessageContaining("DENIED"); + } + } + } + @Test public void regularClientFailsOnFirstCommandWithDelay() { diff --git a/src/test/java/com/lambdaworks/redis/output/ReplayOutputTest.java b/src/test/java/com/lambdaworks/redis/output/ReplayOutputTest.java new file mode 100644 index 0000000000..708e6b04dc --- /dev/null +++ b/src/test/java/com/lambdaworks/redis/output/ReplayOutputTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lambdaworks.redis.output; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.nio.ByteBuffer; +import java.util.Collections; + +import org.junit.Test; + +import com.lambdaworks.redis.codec.StringCodec; + +/** + * @author Mark Paluch + */ +public class ReplayOutputTest { + + @Test + public void shouldReplaySimpleCompletion() { + + ReplayOutput replay = new ReplayOutput<>(); + ValueOutput target = new ValueOutput<>(StringCodec.ASCII); + + replay.multi(1); + replay.set(ByteBuffer.wrap("foo".getBytes())); + replay.complete(1); + + replay.replay(target); + + assertThat(target.get()).isEqualTo("foo"); + } + + @Test + public void shouldReplayNestedCompletion() { + + ReplayOutput replay = new ReplayOutput<>(); + ArrayOutput target = new ArrayOutput<>(StringCodec.ASCII); + + replay.multi(1); + replay.multi(1); + replay.set(ByteBuffer.wrap("foo".getBytes())); + replay.complete(2); + + replay.multi(1); + replay.set(ByteBuffer.wrap("bar".getBytes())); + replay.complete(2); + replay.complete(1); + + replay.replay(target); + + assertThat(target.get().get(0)).isEqualTo(Collections.singletonList("foo")); + assertThat(target.get().get(1)).isEqualTo(Collections.singletonList("bar")); + } + + @Test + public void shouldDecodeErrorResponse() { + + ReplayOutput replay = new ReplayOutput<>(); + ValueOutput target = new ValueOutput<>(StringCodec.ASCII); + + replay.setError(ByteBuffer.wrap("foo".getBytes())); + + replay.replay(target); + + assertThat(replay.getError()).isEqualTo("foo"); + assertThat(target.getError()).isEqualTo("foo"); + } +} diff --git a/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandHandlerTest.java b/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandHandlerTest.java index 6fda7c77cb..a09c9f38ee 100644 --- a/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandHandlerTest.java +++ b/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandHandlerTest.java @@ -19,6 +19,7 @@ import static org.assertj.core.api.Fail.fail; import static org.mockito.Matchers.any; import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.util.Collection; @@ -27,6 +28,7 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.test.util.ReflectionTestUtils; @@ -144,4 +146,79 @@ public void shouldCompleteCommandExceptionallyOnOutputFailure() throws Exception assertThat(ReflectionTestUtils.getField(command, "exception")).isInstanceOf(IllegalStateException.class); } + + @Test + public void shouldDecodeRegularCommand() throws Exception { + + sut.channelRegistered(context); + sut.channelActive(context); + stack.add(command); + + sut.channelRead(context, Unpooled.wrappedBuffer("+OK\r\n".getBytes())); + + assertThat(command.get()).isEqualTo("OK"); + } + + @Test + public void shouldDecodeTwoCommands() throws Exception { + + Command command1 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), + null); + Command command2 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), + null); + + sut.channelRegistered(context); + sut.channelActive(context); + stack.add(command1); + stack.add(command2); + + sut.channelRead(context, Unpooled.wrappedBuffer("+OK\r\n+YEAH\r\n".getBytes())); + + assertThat(command1.get()).isEqualTo("OK"); + assertThat(command2.get()).isEqualTo("YEAH"); + } + + @Test + public void shouldPropagatePubSubResponseToOutput() throws Exception { + + Command command1 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), + null); + + sut.channelRegistered(context); + sut.channelActive(context); + stack.add(command1); + + sut.channelRead(context, Unpooled.wrappedBuffer("*3\r\n$7\r\nmessage\r\n$3\r\nfoo\r\n$3\r\nbar\r\n".getBytes())); + + assertThat(command1.isDone()).isFalse(); + + verify(context).fireChannelRead(any()); + } + + @Test + public void shouldPropagateInterleavedPubSubResponseToOutput() throws Exception { + + Command command1 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), + null); + Command command2 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), + null); + + sut.channelRegistered(context); + sut.channelActive(context); + stack.add(command1); + stack.add(command2); + + sut.channelRead(context, Unpooled + .wrappedBuffer("+OK\r\n*4\r\n$8\r\npmessage\r\n$1\r\n*\r\n$3\r\nfoo\r\n$3\r\nbar\r\n+YEAH\r\n".getBytes())); + + assertThat(command1.get()).isEqualTo("OK"); + assertThat(command2.get()).isEqualTo("YEAH"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(PubSubOutput.class); + verify(context).fireChannelRead(captor.capture()); + + assertThat(captor.getValue().pattern()).isEqualTo("*"); + assertThat(captor.getValue().channel()).isEqualTo("foo"); + assertThat(captor.getValue().get()).isEqualTo("bar"); + } }