From 0f72784eeae054136d41e18342b67cc7a86a6430 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 1 Oct 2018 15:10:19 +0200 Subject: [PATCH] Fix authentication on reconnect for subscribed Pub/Sub connections #868 Lettuce now checks whether the current connection has written a subscription command in addition to the registered channels/patterns. This allows a more meaningful checking of allowed commands so that regular commands can be written to the connection before resubscribing. This change allows authentication again for connections that got reconnected and had previously subscriptions. --- .../StatefulRedisPubSubConnectionImpl.java | 33 +++++++++++++++---- .../redis/pubsub/PubSubCommandTest.java | 31 ++++++++++++++--- 2 files changed, 54 insertions(+), 10 deletions(-) diff --git a/src/main/java/com/lambdaworks/redis/pubsub/StatefulRedisPubSubConnectionImpl.java b/src/main/java/com/lambdaworks/redis/pubsub/StatefulRedisPubSubConnectionImpl.java index 895ec44352..6b179b066c 100644 --- a/src/main/java/com/lambdaworks/redis/pubsub/StatefulRedisPubSubConnectionImpl.java +++ b/src/main/java/com/lambdaworks/redis/pubsub/StatefulRedisPubSubConnectionImpl.java @@ -47,10 +47,12 @@ public class StatefulRedisPubSubConnectionImpl extends StatefulRedisConnec StatefulRedisPubSubConnection { private static final Set ALLOWED_COMMANDS_SUBSCRIBED; + private static final Set SUBSCRIBE_COMMANDS; protected final List> listeners; protected final Set channels; protected final Set patterns; + private volatile boolean subscribeWritten = false; static { @@ -61,6 +63,11 @@ public class StatefulRedisPubSubConnectionImpl extends StatefulRedisConnec ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.UNSUBSCRIBE.name()); ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.PUNSUBSCRIBE.name()); ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.QUIT.name()); + + SUBSCRIBE_COMMANDS = new HashSet<>(2, 1); + + SUBSCRIBE_COMMANDS.add(CommandType.SUBSCRIBE.name()); + SUBSCRIBE_COMMANDS.add(CommandType.PSUBSCRIBE.name()); } /** @@ -146,6 +153,7 @@ public void channelRead(Object msg) { @Override public void activated() { + subscribeWritten = false; super.activated(); resubscribe(); } @@ -153,18 +161,31 @@ public void activated() { @Override public > C dispatch(C command) { - if (!channels.isEmpty() || !patterns.isEmpty()) { - - if (!ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name())) { + if (isSubscribed()) { + validateCommandAllowed(command); + } - throw new RedisException(String.format("Command %s not allowed while subscribed. Allowed commands are: %s", - command.getType().name(), ALLOWED_COMMANDS_SUBSCRIBED)); - } + if (!subscribeWritten && SUBSCRIBE_COMMANDS.contains(command.getType().name())) { + subscribeWritten = true; } + return super.dispatch(command); } + private static void validateCommandAllowed(RedisCommand command) { + + if (!ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name())) { + + throw new RedisException(String.format("Command %s not allowed while subscribed. Allowed commands are: %s", command + .getType().name(), ALLOWED_COMMANDS_SUBSCRIBED)); + } + } + + private boolean isSubscribed() { + return subscribeWritten && (!channels.isEmpty() || !patterns.isEmpty()); + } + /** * Re-subscribe to all previously subscribed channels and patterns. * diff --git a/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandTest.java b/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandTest.java index 339eb40329..59058988b6 100644 --- a/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandTest.java +++ b/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandTest.java @@ -21,16 +21,20 @@ import static org.hamcrest.CoreMatchers.hasItem; import static org.junit.Assert.assertThat; +import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.junit.After; import org.junit.Before; import org.junit.Test; +import com.lambdaworks.Delay; import com.lambdaworks.TestClientResources; import com.lambdaworks.Wait; import com.lambdaworks.redis.*; @@ -91,20 +95,39 @@ public void authWithReconnect() { @Override protected void run(RedisClient client) throws Exception { + RedisPubSubAsyncCommands connection = client.connectPubSub().async(); - connection.addListener(PubSubCommandTest.this); + connection.getStatefulConnection().addListener(PubSubCommandTest.this); connection.auth(passwd); - connection.quit(); + connection.clientSetname("authWithReconnect"); + connection.subscribe(channel); + + assertThat(channels.take()).isEqualTo(channel); + + long id = findNamedClient("authWithReconnect"); + redis.clientKill(KillArgs.Builder.id(id)); - Thread.sleep(100); + Delay.delay(Duration.ofMillis(100)); Wait.untilTrue(connection::isOpen).waitOrTimeout(); - connection.subscribe(channel); assertThat(channels.take()).isEqualTo(channel); } }; } + private long findNamedClient(String name) { + + Pattern pattern = Pattern.compile(".*id=(\\d+).*name=" + name + ".*", Pattern.MULTILINE); + String clients = redis.clientList(); + Matcher matcher = pattern.matcher(clients); + + if (!matcher.find()) { + throw new IllegalStateException("Cannot find PubSub client in: " + clients); + } + + return Long.parseLong(matcher.group(1)); + } + @Test(timeout = 2000) public void message() throws Exception { pubsub.subscribe(channel);