Skip to content

Commit

Permalink
Fix authentication on reconnect for subscribed Pub/Sub connections #868
Browse files Browse the repository at this point in the history
Lettuce now checks whether the current connection has written a subscription command in addition to the registered channels/patterns. This allows a more meaningful checking of allowed commands so that regular commands can be written to the connection before resubscribing.

This change allows authentication again for connections that got reconnected and had previously subscriptions.
  • Loading branch information
mp911de committed Oct 1, 2018
1 parent 1bff314 commit 0f72784
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ public class StatefulRedisPubSubConnectionImpl<K, V> extends StatefulRedisConnec
StatefulRedisPubSubConnection<K, V> {

private static final Set<String> ALLOWED_COMMANDS_SUBSCRIBED;
private static final Set<String> SUBSCRIBE_COMMANDS;

protected final List<RedisPubSubListener<K, V>> listeners;
protected final Set<K> channels;
protected final Set<K> patterns;
private volatile boolean subscribeWritten = false;

static {

Expand All @@ -61,6 +63,11 @@ public class StatefulRedisPubSubConnectionImpl<K, V> extends StatefulRedisConnec
ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.UNSUBSCRIBE.name());
ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.PUNSUBSCRIBE.name());
ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.QUIT.name());

SUBSCRIBE_COMMANDS = new HashSet<>(2, 1);

SUBSCRIBE_COMMANDS.add(CommandType.SUBSCRIBE.name());
SUBSCRIBE_COMMANDS.add(CommandType.PSUBSCRIBE.name());
}

/**
Expand Down Expand Up @@ -146,25 +153,39 @@ public void channelRead(Object msg) {

@Override
public void activated() {
subscribeWritten = false;
super.activated();
resubscribe();
}

@Override
public <T, C extends RedisCommand<K, V, T>> C dispatch(C command) {

if (!channels.isEmpty() || !patterns.isEmpty()) {

if (!ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name())) {
if (isSubscribed()) {
validateCommandAllowed(command);
}

throw new RedisException(String.format("Command %s not allowed while subscribed. Allowed commands are: %s",
command.getType().name(), ALLOWED_COMMANDS_SUBSCRIBED));
}
if (!subscribeWritten && SUBSCRIBE_COMMANDS.contains(command.getType().name())) {
subscribeWritten = true;
}


return super.dispatch(command);
}

private static void validateCommandAllowed(RedisCommand<?, ?, ?> command) {

if (!ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name())) {

throw new RedisException(String.format("Command %s not allowed while subscribed. Allowed commands are: %s", command
.getType().name(), ALLOWED_COMMANDS_SUBSCRIBED));
}
}

private boolean isSubscribed() {
return subscribeWritten && (!channels.isEmpty() || !patterns.isEmpty());
}

/**
* Re-subscribe to all previously subscribed channels and patterns.
*
Expand Down
31 changes: 27 additions & 4 deletions src/test/java/com/lambdaworks/redis/pubsub/PubSubCommandTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,20 @@
import static org.hamcrest.CoreMatchers.hasItem;
import static org.junit.Assert.assertThat;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import com.lambdaworks.Delay;
import com.lambdaworks.TestClientResources;
import com.lambdaworks.Wait;
import com.lambdaworks.redis.*;
Expand Down Expand Up @@ -91,20 +95,39 @@ public void authWithReconnect() {
@Override
protected void run(RedisClient client) throws Exception {


RedisPubSubAsyncCommands<String, String> connection = client.connectPubSub().async();
connection.addListener(PubSubCommandTest.this);
connection.getStatefulConnection().addListener(PubSubCommandTest.this);
connection.auth(passwd);
connection.quit();
connection.clientSetname("authWithReconnect");
connection.subscribe(channel);

assertThat(channels.take()).isEqualTo(channel);

long id = findNamedClient("authWithReconnect");
redis.clientKill(KillArgs.Builder.id(id));

Thread.sleep(100);
Delay.delay(Duration.ofMillis(100));
Wait.untilTrue(connection::isOpen).waitOrTimeout();

connection.subscribe(channel);
assertThat(channels.take()).isEqualTo(channel);
}
};
}

private long findNamedClient(String name) {

Pattern pattern = Pattern.compile(".*id=(\\d+).*name=" + name + ".*", Pattern.MULTILINE);
String clients = redis.clientList();
Matcher matcher = pattern.matcher(clients);

if (!matcher.find()) {
throw new IllegalStateException("Cannot find PubSub client in: " + clients);
}

return Long.parseLong(matcher.group(1));
}

@Test(timeout = 2000)
public void message() throws Exception {
pubsub.subscribe(channel);
Expand Down

0 comments on commit 0f72784

Please sign in to comment.