Skip to content

Commit

Permalink
Guard Pub/Sub callbacks against exceptions #997
Browse files Browse the repository at this point in the history
Pub/Sub listener callbacks are now guarded against exceptions bubbling up into channel processing. Instead, exceptions are logged. Listener notification stops on the first exception.

These guards prevent exceptions interrupting the state update flow which could previously cause the state machine of decoding leave in an invalid state.
  • Loading branch information
mp911de committed Mar 14, 2019
1 parent d30ce30 commit 54a2912
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 5 deletions.
20 changes: 16 additions & 4 deletions src/main/java/io/lettuce/core/pubsub/PubSubCommandHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

/**
* A netty {@link ChannelHandler} responsible for writing Redis Pub/Sub commands and reading the response stream from the
Expand All @@ -45,6 +47,8 @@
*/
public class PubSubCommandHandler<K, V> extends CommandHandler {

private static final InternalLogger logger = InternalLoggerFactory.getInstance(PubSubCommandHandler.class);

private final PubSubEndpoint<K, V> endpoint;
private final RedisCodec<K, V> codec;
private final Deque<ReplayOutput<K, V>> queue = new ArrayDeque<>();
Expand Down Expand Up @@ -91,7 +95,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup

RedisCommand<?, ?, ?> peek = getStack().peek();
canComplete(peek);
endpoint.notifyMessage(output);
doNotifyMessage(output);
output = new PubSubOutput<>(codec);
}

Expand All @@ -103,7 +107,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup
while ((replay = queue.poll()) != null) {

replay.replay(output);
endpoint.notifyMessage(output);
doNotifyMessage(output);
output = new PubSubOutput<>(codec);
}

Expand All @@ -113,7 +117,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup
return;
}

endpoint.notifyMessage(output);
doNotifyMessage(output);
output = new PubSubOutput<>(codec);
}

Expand Down Expand Up @@ -201,7 +205,15 @@ private static boolean isPubSubMessage(ResponseHeaderReplayOutput<?, ?> replay)
protected void afterDecode(ChannelHandlerContext ctx, RedisCommand<?, ?, ?> command) {

if (command.getOutput() instanceof PubSubOutput) {
endpoint.notifyMessage((PubSubOutput) command.getOutput());
doNotifyMessage((PubSubOutput) command.getOutput());
}
}

private void doNotifyMessage(PubSubOutput<K, V, V> output) {
try {
endpoint.notifyMessage(output);
} catch (Exception e) {
logger.error("Unexpected error occurred in PubSubEndpoint.notifyMessage", e);
}
}

Expand Down
9 changes: 8 additions & 1 deletion src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@
import io.lettuce.core.resource.ClientResources;
import io.netty.channel.Channel;
import io.netty.util.internal.ConcurrentSet;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

/**
* @author Mark Paluch
*/
public class PubSubEndpoint<K, V> extends DefaultEndpoint {

private static final InternalLogger logger = InternalLoggerFactory.getInstance(PubSubEndpoint.class);
private static final Set<String> ALLOWED_COMMANDS_SUBSCRIBED;
private static final Set<String> SUBSCRIBE_COMMANDS;
private final List<RedisPubSubListener<K, V>> listeners = new CopyOnWriteArrayList<>();
Expand Down Expand Up @@ -167,7 +170,11 @@ public void notifyMessage(PubSubOutput<K, V, V> output) {
}

updateInternalState(output);
notifyListeners(output);
try {
notifyListeners(output);
} catch (Exception e) {
logger.error("Unexpected error occurred in RedisPubSubListener callback", e);
}
}

protected void notifyListeners(PubSubOutput<K, V, V> output) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand All @@ -31,6 +32,7 @@
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.mockito.stubbing.Answer;
import org.springframework.test.util.ReflectionTestUtils;

import io.lettuce.core.ClientOptions;
Expand All @@ -51,6 +53,7 @@

/**
* @author Mark Paluch
* @author Giridhar Kannan
*/
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
Expand Down Expand Up @@ -264,6 +267,36 @@ void shouldCompleteWithChunkedResponseOutOfBand() throws Exception {
assertThat(captor.getAllValues().get(1).channel()).isEqualTo("b");
}

@Test
void shouldCompleteUnsubscribe() throws Exception {

Command<String, String, String> subCmd = new Command<>(CommandType.SUBSCRIBE,
new PubSubOutput<>(new Utf8StringCodec()), null);
Command<String, String, String> unSubCmd = new Command<>(CommandType.UNSUBSCRIBE, new PubSubOutput<>(
new Utf8StringCodec()), null);

doAnswer((Answer<PubSubEndpoint<String, String>>) inv -> {
PubSubOutput<String, String, String> out = inv.getArgument(0);
if (out.type() == PubSubOutput.Type.message) {
throw new NullPointerException();
}
return endpoint;
}).when(endpoint).notifyMessage(any());

sut.channelRegistered(context);
sut.channelActive(context);

stack.add(subCmd);
stack.add(unSubCmd);
ByteBuf buf = responseBytes("*3\r\n$9\r\nsubscribe\r\n$10\r\ntest_sub_0\r\n:1\r\n"
+ "*3\r\n$7\r\nmessage\r\n$10\r\ntest_sub_0\r\n$3\r\nabc\r\n"
+ "*3\r\n$11\r\nunsubscribe\r\n$10\r\ntest_sub_0\r\n:0\r\n");
sut.channelRead(context, buf);
sut.channelRead(context, responseBytes("*3\r\n$7\r\nmessage\r\n$10\r\ntest_sub_1\r\n$3\r\nabc\r\n"));

assertThat(unSubCmd.isDone()).isTrue();
}

@Test
void shouldCompleteWithChunkedResponseInterleavedSending() throws Exception {

Expand Down
29 changes: 29 additions & 0 deletions src/test/java/io/lettuce/core/pubsub/PubSubEndpointUnitTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.assertj.core.api.Assertions.assertThat;

import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -86,6 +87,34 @@ void addsAndRemovesChannels() {
assertThat(sut.getChannels()).isEmpty();
}

@Test
void listenerNotificationShouldFailGracefully() {

PubSubEndpoint<byte[], byte[]> sut = new PubSubEndpoint<>(ClientOptions.create(), TestClientResources.get());

AtomicInteger notified = new AtomicInteger();

sut.addListener(new RedisPubSubAdapter<byte[], byte[]>() {
@Override
public void message(byte[] channel, byte[] message) {

notified.incrementAndGet();
throw new UnsupportedOperationException();
}
});

sut.addListener(new RedisPubSubAdapter<byte[], byte[]>() {
@Override
public void message(byte[] channel, byte[] message) {
notified.incrementAndGet();
}
});

sut.notifyMessage(createMessage("message", "channel1", ByteArrayCodec.INSTANCE));

assertThat(notified).hasValue(1);
}

private static <K, V> PubSubOutput<K, V, V> createMessage(String action, String channel, RedisCodec<K, V> codec) {

PubSubOutput<K, V, V> output = new PubSubOutput<>(codec);
Expand Down

0 comments on commit 54a2912

Please sign in to comment.