diff --git a/src/main/java/com/lambdaworks/redis/pubsub/StatefulRedisPubSubConnectionImpl.java b/src/main/java/com/lambdaworks/redis/pubsub/StatefulRedisPubSubConnectionImpl.java index 895ec44352..6b179b066c 100644 --- a/src/main/java/com/lambdaworks/redis/pubsub/StatefulRedisPubSubConnectionImpl.java +++ b/src/main/java/com/lambdaworks/redis/pubsub/StatefulRedisPubSubConnectionImpl.java @@ -47,10 +47,12 @@ public class StatefulRedisPubSubConnectionImpl extends StatefulRedisConnec StatefulRedisPubSubConnection { private static final Set ALLOWED_COMMANDS_SUBSCRIBED; + private static final Set SUBSCRIBE_COMMANDS; protected final List> listeners; protected final Set channels; protected final Set patterns; + private volatile boolean subscribeWritten = false; static { @@ -61,6 +63,11 @@ public class StatefulRedisPubSubConnectionImpl extends StatefulRedisConnec ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.UNSUBSCRIBE.name()); ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.PUNSUBSCRIBE.name()); ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.QUIT.name()); + + SUBSCRIBE_COMMANDS = new HashSet<>(2, 1); + + SUBSCRIBE_COMMANDS.add(CommandType.SUBSCRIBE.name()); + SUBSCRIBE_COMMANDS.add(CommandType.PSUBSCRIBE.name()); } /** @@ -146,6 +153,7 @@ public void channelRead(Object msg) { @Override public void activated() { + subscribeWritten = false; super.activated(); resubscribe(); } @@ -153,18 +161,31 @@ public void activated() { @Override public > C dispatch(C command) { - if (!channels.isEmpty() || !patterns.isEmpty()) { - - if (!ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name())) { + if (isSubscribed()) { + validateCommandAllowed(command); + } - throw new RedisException(String.format("Command %s not allowed while subscribed. Allowed commands are: %s", - command.getType().name(), ALLOWED_COMMANDS_SUBSCRIBED)); - } + if (!subscribeWritten && SUBSCRIBE_COMMANDS.contains(command.getType().name())) { + subscribeWritten = true; } + return super.dispatch(command); } + private static void validateCommandAllowed(RedisCommand command) { + + if (!ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name())) { + + throw new RedisException(String.format("Command %s not allowed while subscribed. Allowed commands are: %s", command + .getType().name(), ALLOWED_COMMANDS_SUBSCRIBED)); + } + } + + private boolean isSubscribed() { + return subscribeWritten && (!channels.isEmpty() || !patterns.isEmpty()); + } + /** * Re-subscribe to all previously subscribed channels and patterns. * diff --git a/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandTest.java b/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandTest.java index 339eb40329..59058988b6 100644 --- a/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandTest.java +++ b/src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandTest.java @@ -21,16 +21,20 @@ import static org.hamcrest.CoreMatchers.hasItem; import static org.junit.Assert.assertThat; +import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.junit.After; import org.junit.Before; import org.junit.Test; +import com.lambdaworks.Delay; import com.lambdaworks.TestClientResources; import com.lambdaworks.Wait; import com.lambdaworks.redis.*; @@ -91,20 +95,39 @@ public void authWithReconnect() { @Override protected void run(RedisClient client) throws Exception { + RedisPubSubAsyncCommands connection = client.connectPubSub().async(); - connection.addListener(PubSubCommandTest.this); + connection.getStatefulConnection().addListener(PubSubCommandTest.this); connection.auth(passwd); - connection.quit(); + connection.clientSetname("authWithReconnect"); + connection.subscribe(channel); + + assertThat(channels.take()).isEqualTo(channel); + + long id = findNamedClient("authWithReconnect"); + redis.clientKill(KillArgs.Builder.id(id)); - Thread.sleep(100); + Delay.delay(Duration.ofMillis(100)); Wait.untilTrue(connection::isOpen).waitOrTimeout(); - connection.subscribe(channel); assertThat(channels.take()).isEqualTo(channel); } }; } + private long findNamedClient(String name) { + + Pattern pattern = Pattern.compile(".*id=(\\d+).*name=" + name + ".*", Pattern.MULTILINE); + String clients = redis.clientList(); + Matcher matcher = pattern.matcher(clients); + + if (!matcher.find()) { + throw new IllegalStateException("Cannot find PubSub client in: " + clients); + } + + return Long.parseLong(matcher.group(1)); + } + @Test(timeout = 2000) public void message() throws Exception { pubsub.subscribe(channel);