From fe18471193c5981a3925e5be51379b52e0b26de7 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 19 Mar 2018 15:20:12 +0100 Subject: [PATCH] Inspect Pub/Sub responses for interleaved messages #724 Lettuce now inspects Redis responses via PubSubCommandHandler and ReplayOutput whether a received response is a Pub/Sub message or whether the response belongs to a command on the protocol stack. Introspection is required as Redis responses may contain interleaved messages that do not belong to a command or may arrive before the command response. Previously, interleaved messages could get used to complete commands on the protocol stack which causes a defunct protocol state. --- .../io/lettuce/core/output/ReplayOutput.java | 197 ++++++++++++++++++ .../lettuce/core/protocol/CommandHandler.java | 57 ++++- .../core/pubsub/PubSubCommandHandler.java | 154 +++++++++++++- .../lettuce/core/output/ReplayOutputTest.java | 82 ++++++++ .../core/pubsub/PubSubCommandHandlerTest.java | 77 +++++++ 5 files changed, 551 insertions(+), 16 deletions(-) create mode 100644 src/main/java/io/lettuce/core/output/ReplayOutput.java create mode 100644 src/test/java/io/lettuce/core/output/ReplayOutputTest.java diff --git a/src/main/java/io/lettuce/core/output/ReplayOutput.java b/src/main/java/io/lettuce/core/output/ReplayOutput.java new file mode 100644 index 0000000000..15b5ab1baa --- /dev/null +++ b/src/main/java/io/lettuce/core/output/ReplayOutput.java @@ -0,0 +1,197 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.lettuce.core.output; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import io.lettuce.core.codec.RedisCodec; +import io.lettuce.core.codec.StringCodec; + +/** + * Replayable {@link CommandOutput} capturing output signals to replay these on a target {@link CommandOutput}. Replay is useful + * when the response requires inspection prior to dispatching the actual output to a command target. + * + * @author Mark Paluch + * @since 5.0.3 + */ +public class ReplayOutput extends CommandOutput> { + + /** + * Initialize a new instance that encodes and decodes keys and values using the supplied codec. + */ + public ReplayOutput() { + super((RedisCodec) StringCodec.ASCII, new ArrayList<>()); + } + + @Override + public void set(ByteBuffer bytes) { + output.add(new BulkString(bytes)); + } + + @Override + public void set(long integer) { + output.add(new Integer(integer)); + } + + @Override + public void setError(ByteBuffer error) { + error.mark(); + output.add(new ErrorBytes(error)); + error.reset(); + super.setError(error); + } + + @Override + public void setError(String error) { + output.add(new ErrorString(error)); + super.setError(error); + } + + @Override + public void complete(int depth) { + output.add(new Complete(depth)); + } + + @Override + public void multi(int count) { + output.add(new Multi(count)); + } + + /** + * Replay all captured signals on a {@link CommandOutput}. + * + * @param target the target {@link CommandOutput}. + */ + public void replay(CommandOutput target) { + + for (Signal signal : output) { + signal.replay(target); + } + } + + /** + * Encapsulates a replayable decoding signal. + */ + public static abstract class Signal { + + /** + * Replay the signal on a {@link CommandOutput}. + * + * @param target + */ + protected abstract void replay(CommandOutput target); + } + + abstract static class BulkStringSupport extends Signal { + + final ByteBuffer message; + + BulkStringSupport(ByteBuffer message) { + + if (message != null) { + + // need to copy the buffer to prevent buffer lifecycle mismatch + this.message = ByteBuffer.allocate(message.remaining()); + this.message.put(message); + this.message.rewind(); + } else { + this.message = null; + } + } + } + + public static class BulkString extends BulkStringSupport { + + BulkString(ByteBuffer message) { + super(message); + } + + @Override + protected void replay(CommandOutput target) { + target.set(message); + } + } + + static class Integer extends Signal { + + final long message; + + Integer(long message) { + this.message = message; + } + + @Override + protected void replay(CommandOutput target) { + target.set(message); + } + } + + public static class ErrorBytes extends BulkStringSupport { + + ErrorBytes(ByteBuffer message) { + super(message); + } + + @Override + protected void replay(CommandOutput target) { + target.setError(message); + } + } + + static class ErrorString extends Signal { + + final String message; + + ErrorString(String message) { + this.message = message; + } + + @Override + protected void replay(CommandOutput target) { + target.setError(message); + } + } + + static class Multi extends Signal { + + final int count; + + Multi(int count) { + this.count = count; + } + + @Override + protected void replay(CommandOutput target) { + target.multi(count); + } + } + + static class Complete extends Signal { + + final int depth; + + public Complete(int depth) { + this.depth = depth; + } + + @Override + protected void replay(CommandOutput target) { + target.complete(depth); + } + } +} diff --git a/src/main/java/io/lettuce/core/protocol/CommandHandler.java b/src/main/java/io/lettuce/core/protocol/CommandHandler.java index da40a7e2d8..45f065e60d 100644 --- a/src/main/java/io/lettuce/core/protocol/CommandHandler.java +++ b/src/main/java/io/lettuce/core/protocol/CommandHandler.java @@ -411,7 +411,7 @@ private void addToStack(RedisCommand command, ChannelPromise promise) { if (command.getOutput() == null) { // fire&forget commands are excluded from metrics - command.complete(); + complete(command); } RedisCommand redisCommand = potentiallyWrapLatencyCommand(command); @@ -552,16 +552,18 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup onProtectedMode(command.getOutput().getError()); } else { - stack.poll(); + if (canComplete(command)) { + stack.poll(); - try { - command.complete(); - } catch (Exception e) { - logger.warn("{} Unexpected exception during request: {}", logPrefix, e.toString(), e); + try { + complete(command); + } catch (Exception e) { + logger.warn("{} Unexpected exception during request: {}", logPrefix, e.toString(), e); + } } } - afterComplete(ctx, command); + afterDecode(ctx, command); } if (buffer.refCnt() != 0) { @@ -569,10 +571,36 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup } } + /** + * Decoding hook: Can the buffer be decoded to a command. + * + * @param buffer + * @return + */ protected boolean canDecode(ByteBuf buffer) { return !stack.isEmpty() && buffer.isReadable(); } + /** + * Decoding hook: Can the command be completed. + * + * @param command + * @return + */ + protected boolean canComplete(RedisCommand command) { + return true; + } + + /** + * Decoding hook: Complete a command. + * + * @param command + * @see RedisCommand#complete() + */ + protected void complete(RedisCommand command) { + command.complete(); + } + private boolean decode(ChannelHandlerContext ctx, ByteBuf buffer, RedisCommand command) { if (latencyMetricsEnabled && command instanceof WithLatency) { @@ -596,7 +624,7 @@ private boolean decode(ChannelHandlerContext ctx, ByteBuf buffer, RedisCommand command) { - if (!decode(buffer, command, command.getOutput())) { + if (!decode(buffer, command, getCommandOutput(command))) { if (command instanceof DemandAware.Sink) { @@ -616,6 +644,17 @@ private boolean decode0(ChannelHandlerContext ctx, ByteBuf buffer, RedisCommand< return true; } + /** + * Decoding hook: Retrieve {@link CommandOutput} for {@link RedisCommand} decoding. + * + * @param command + * @return + * @see RedisCommand#getOutput() + */ + protected CommandOutput getCommandOutput(RedisCommand command) { + return command.getOutput(); + } + protected boolean decode(ByteBuf buffer, CommandOutput output) { return rsm.decode(buffer, output); } @@ -682,7 +721,7 @@ private void onProtectedMode(String message) { * @param ctx * @param command */ - protected void afterComplete(ChannelHandlerContext ctx, RedisCommand command) { + protected void afterDecode(ChannelHandlerContext ctx, RedisCommand command) { } private void recordLatency(WithLatency withLatency, ProtocolKeyword commandType) { diff --git a/src/main/java/io/lettuce/core/pubsub/PubSubCommandHandler.java b/src/main/java/io/lettuce/core/pubsub/PubSubCommandHandler.java index 0b2197a467..d816f6d5da 100644 --- a/src/main/java/io/lettuce/core/pubsub/PubSubCommandHandler.java +++ b/src/main/java/io/lettuce/core/pubsub/PubSubCommandHandler.java @@ -15,8 +15,15 @@ */ package io.lettuce.core.pubsub; +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Deque; + import io.lettuce.core.ClientOptions; import io.lettuce.core.codec.RedisCodec; +import io.lettuce.core.codec.StringCodec; +import io.lettuce.core.output.CommandOutput; +import io.lettuce.core.output.ReplayOutput; import io.lettuce.core.protocol.CommandHandler; import io.lettuce.core.protocol.RedisCommand; import io.lettuce.core.resource.ClientResources; @@ -25,8 +32,11 @@ import io.netty.channel.ChannelHandlerContext; /** - * A netty {@link ChannelHandler} responsible for writing redis pub/sub commands and reading the response stream from the - * server. + * A netty {@link ChannelHandler} responsible for writing Redis Pub/Sub commands and reading the response stream from the + * server. {@link PubSubCommandHandler} accounts for Pub/Sub message notification calling back + * {@link PubSubEndpoint#notifyMessage(PubSubOutput)}. Redis responses can be interleaved in the sense that a response contains + * a Pub/Sub message first, then a command response. Possible interleave is introspected via {@link ResponseHeaderReplayOutput} + * and decoding hooks. * * @param Key type. * @param Value type. @@ -37,6 +47,9 @@ public class PubSubCommandHandler extends CommandHandler { private final PubSubEndpoint endpoint; private final RedisCodec codec; + private final Deque> queue = new ArrayDeque<>(); + + private ResponseHeaderReplayOutput replay; private PubSubOutput output; /** @@ -49,6 +62,7 @@ public class PubSubCommandHandler extends CommandHandler { */ public PubSubCommandHandler(ClientOptions clientOptions, ClientResources clientResources, RedisCodec codec, PubSubEndpoint endpoint) { + super(clientOptions, clientResources, endpoint); this.endpoint = endpoint; @@ -56,13 +70,32 @@ public PubSubCommandHandler(ClientOptions clientOptions, ClientResources clientR this.output = new PubSubOutput<>(codec); } + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + + replay = null; + queue.clear(); + + super.channelInactive(ctx); + } + @SuppressWarnings("unchecked") @Override protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws InterruptedException { - super.decode(ctx, buffer); + if (!getStack().isEmpty()) { + super.decode(ctx, buffer); + } + + ReplayOutput replay; + while ((replay = queue.poll()) != null) { + + replay.replay(output); + endpoint.notifyMessage(output); + output = new PubSubOutput<>(codec); + } - while (buffer.isReadable()) { + while (super.getStack().isEmpty() && buffer.isReadable()) { if (!super.decode(buffer, output)) { return; @@ -70,9 +103,10 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup endpoint.notifyMessage(output); output = new PubSubOutput<>(codec); - - buffer.discardReadBytes(); } + + buffer.discardReadBytes(); + } @Override @@ -80,11 +114,117 @@ protected boolean canDecode(ByteBuf buffer) { return super.canDecode(buffer) && output.type() == null; } + @Override + protected boolean canComplete(RedisCommand command) { + + if (isPubSubMessage(replay)) { + + queue.add(replay); + replay = null; + return false; + } + + return super.canComplete(command); + } + + @Override + protected void complete(RedisCommand command) { + + if (replay != null && command.getOutput() != null) { + try { + replay.replay(command.getOutput()); + } catch (Exception e) { + command.completeExceptionally(e); + } + replay = null; + } + + super.complete(command); + } + + /** + * Check whether {@link ResponseHeaderReplayOutput} contains a Pub/Sub message that requires Pub/Sub dispatch instead of to + * be used as Command output. + * + * @param replay + * @return + */ + private static boolean isPubSubMessage(ResponseHeaderReplayOutput replay) { + + if (replay == null) { + return false; + } + + String firstElement = replay.firstElement; + if (replay.multiCount != null && firstElement != null) { + + if (replay.multiCount == 3 && firstElement.equalsIgnoreCase(PubSubOutput.Type.message.name())) { + return true; + } + + if (replay.multiCount == 4 && firstElement.equalsIgnoreCase(PubSubOutput.Type.pmessage.name())) { + return true; + } + } + + return false; + } + + @Override + protected CommandOutput getCommandOutput(RedisCommand command) { + + if (getStack().isEmpty() || command.getOutput() == null) { + return super.getCommandOutput(command); + } + + if (replay == null) { + replay = new ResponseHeaderReplayOutput<>(); + } + + return replay; + } + @Override @SuppressWarnings("unchecked") - protected void afterComplete(ChannelHandlerContext ctx, RedisCommand command) { + protected void afterDecode(ChannelHandlerContext ctx, RedisCommand command) { + if (command.getOutput() instanceof PubSubOutput) { endpoint.notifyMessage((PubSubOutput) command.getOutput()); } } + + /** + * Inspectable {@link ReplayOutput} to investigate the first multi and string response elements. + * + * @param + * @param + */ + static class ResponseHeaderReplayOutput extends ReplayOutput { + + Integer multiCount; + String firstElement; + + @Override + public void set(ByteBuffer bytes) { + + if (firstElement == null && bytes != null && bytes.remaining() > 0) { + + bytes.mark(); + firstElement = StringCodec.ASCII.decodeKey(bytes); + bytes.reset(); + } + + super.set(bytes); + } + + @Override + public void multi(int count) { + + if (multiCount == null) { + multiCount = count; + } + + super.multi(count); + } + } } diff --git a/src/test/java/io/lettuce/core/output/ReplayOutputTest.java b/src/test/java/io/lettuce/core/output/ReplayOutputTest.java new file mode 100644 index 0000000000..9686d38348 --- /dev/null +++ b/src/test/java/io/lettuce/core/output/ReplayOutputTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.lettuce.core.output; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.nio.ByteBuffer; +import java.util.Collections; + +import org.junit.Test; + +import io.lettuce.core.codec.StringCodec; + +/** + * @author Mark Paluch + */ +public class ReplayOutputTest { + + @Test + public void shouldReplaySimpleCompletion() { + + ReplayOutput replay = new ReplayOutput<>(); + ValueOutput target = new ValueOutput<>(StringCodec.ASCII); + + replay.multi(1); + replay.set(ByteBuffer.wrap("foo".getBytes())); + replay.complete(1); + + replay.replay(target); + + assertThat(target.get()).isEqualTo("foo"); + } + + @Test + public void shouldReplayNestedCompletion() { + + ReplayOutput replay = new ReplayOutput<>(); + ArrayOutput target = new ArrayOutput<>(StringCodec.ASCII); + + replay.multi(1); + replay.multi(1); + replay.set(ByteBuffer.wrap("foo".getBytes())); + replay.complete(2); + + replay.multi(1); + replay.set(ByteBuffer.wrap("bar".getBytes())); + replay.complete(2); + replay.complete(1); + + replay.replay(target); + + assertThat(target.get().get(0)).isEqualTo(Collections.singletonList("foo")); + assertThat(target.get().get(1)).isEqualTo(Collections.singletonList("bar")); + } + + @Test + public void shouldDecodeErrorResponse() { + + ReplayOutput replay = new ReplayOutput<>(); + ValueOutput target = new ValueOutput<>(StringCodec.ASCII); + + replay.setError(ByteBuffer.wrap("foo".getBytes())); + + replay.replay(target); + + assertThat(replay.getError()).isEqualTo("foo"); + assertThat(target.getError()).isEqualTo("foo"); + } +} diff --git a/src/test/java/io/lettuce/core/pubsub/PubSubCommandHandlerTest.java b/src/test/java/io/lettuce/core/pubsub/PubSubCommandHandlerTest.java index 468d2b3b6d..cac28b762f 100644 --- a/src/test/java/io/lettuce/core/pubsub/PubSubCommandHandlerTest.java +++ b/src/test/java/io/lettuce/core/pubsub/PubSubCommandHandlerTest.java @@ -17,6 +17,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Matchers.any; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.util.Queue; @@ -24,6 +25,7 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.test.util.ReflectionTestUtils; @@ -109,4 +111,79 @@ public void shouldCompleteCommandExceptionallyOnOutputFailure() throws Exception assertThat(ReflectionTestUtils.getField(command, "exception")).isInstanceOf(IllegalStateException.class); } + + @Test + public void shouldDecodeRegularCommand() throws Exception { + + sut.channelRegistered(context); + sut.channelActive(context); + stack.add(command); + + sut.channelRead(context, Unpooled.wrappedBuffer("+OK\r\n".getBytes())); + + assertThat(command.get()).isEqualTo("OK"); + } + + @Test + public void shouldDecodeTwoCommands() throws Exception { + + Command command1 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), + null); + Command command2 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), + null); + + sut.channelRegistered(context); + sut.channelActive(context); + stack.add(command1); + stack.add(command2); + + sut.channelRead(context, Unpooled.wrappedBuffer("+OK\r\n+YEAH\r\n".getBytes())); + + assertThat(command1.get()).isEqualTo("OK"); + assertThat(command2.get()).isEqualTo("YEAH"); + } + + @Test + public void shouldPropagatePubSubResponseToOutput() throws Exception { + + Command command1 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), + null); + + sut.channelRegistered(context); + sut.channelActive(context); + stack.add(command1); + + sut.channelRead(context, Unpooled.wrappedBuffer("*3\r\n$7\r\nmessage\r\n$3\r\nfoo\r\n$3\r\nbar\r\n".getBytes())); + + assertThat(command1.isDone()).isFalse(); + + verify(endpoint).notifyMessage(any()); + } + + @Test + public void shouldPropagateInterleavedPubSubResponseToOutput() throws Exception { + + Command command1 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), + null); + Command command2 = new Command<>(CommandType.APPEND, new StatusOutput<>(new Utf8StringCodec()), + null); + + sut.channelRegistered(context); + sut.channelActive(context); + stack.add(command1); + stack.add(command2); + + sut.channelRead(context, Unpooled + .wrappedBuffer("+OK\r\n*4\r\n$8\r\npmessage\r\n$1\r\n*\r\n$3\r\nfoo\r\n$3\r\nbar\r\n+YEAH\r\n".getBytes())); + + assertThat(command1.get()).isEqualTo("OK"); + assertThat(command2.get()).isEqualTo("YEAH"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(PubSubOutput.class); + verify(endpoint).notifyMessage(captor.capture()); + + assertThat(captor.getValue().pattern()).isEqualTo("*"); + assertThat(captor.getValue().channel()).isEqualTo("foo"); + assertThat(captor.getValue().get()).isEqualTo("bar"); + } }