diff --git a/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java b/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java index 8d884445c5..e93fed80c2 100644 --- a/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java +++ b/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java @@ -1,5 +1,5 @@ /* - * Copyright 2011-2016 the original author or authors. + * Copyright 2011-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,13 +15,16 @@ */ package io.lettuce.core.pubsub; +import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; import io.lettuce.core.ClientOptions; +import io.lettuce.core.RedisException; +import io.lettuce.core.protocol.CommandType; import io.lettuce.core.protocol.DefaultEndpoint; - +import io.lettuce.core.protocol.RedisCommand; import io.netty.util.internal.ConcurrentSet; /** @@ -29,10 +32,22 @@ */ public class PubSubEndpoint extends DefaultEndpoint { + private static final Set ALLOWED_COMMANDS_SUBSCRIBED; private final List> listeners = new CopyOnWriteArrayList<>(); private final Set channels; private final Set patterns; + static { + + ALLOWED_COMMANDS_SUBSCRIBED = new HashSet<>(5, 1); + + ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.SUBSCRIBE.name()); + ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.PSUBSCRIBE.name()); + ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.UNSUBSCRIBE.name()); + ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.PUNSUBSCRIBE.name()); + ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.QUIT.name()); + } + /** * Initialize a new instance that handles commands from the supplied queue. * @@ -75,6 +90,21 @@ public Set getPatterns() { return patterns; } + @Override + public RedisCommand write(RedisCommand command) { + + if (!channels.isEmpty() || !patterns.isEmpty()) { + + if (!ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name())) { + + throw new RedisException(String.format("Command %s not allowed while subscribed. Allowed commands are: %s", + command.getType().name(), ALLOWED_COMMANDS_SUBSCRIBED)); + } + } + + return super.write(command); + } + public void notifyMessage(PubSubOutput output) { // drop empty messages diff --git a/src/test/java/io/lettuce/SslTest.java b/src/test/java/io/lettuce/SslTest.java index b92efc148e..132ccf6dea 100644 --- a/src/test/java/io/lettuce/SslTest.java +++ b/src/test/java/io/lettuce/SslTest.java @@ -261,10 +261,10 @@ public void pubSubSslAndBreakConnection() throws Exception { assertThat(future.get()).doesNotContain("c1", "c2"); assertThat(future.isDone()).isEqualTo(true); - RedisFuture> defectFuture = connection.pubsubChannels(); + RedisFuture defectFuture = connection.subscribe("foo"); try { - assertThat(defectFuture.get()).doesNotContain("c1", "c2"); + defectFuture.get(); fail("Missing ExecutionException with nested SSLHandshakeException"); } catch (InterruptedException e) { fail("Missing ExecutionException with nested SSLHandshakeException"); diff --git a/src/test/java/io/lettuce/core/pubsub/PubSubCommandTest.java b/src/test/java/io/lettuce/core/pubsub/PubSubCommandTest.java index 809df3b36d..177d734530 100644 --- a/src/test/java/io/lettuce/core/pubsub/PubSubCommandTest.java +++ b/src/test/java/io/lettuce/core/pubsub/PubSubCommandTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2011-2016 the original author or authors. + * Copyright 2011-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package io.lettuce.core.pubsub; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Fail.fail; import static org.hamcrest.CoreMatchers.hasItem; import static org.junit.Assert.assertThat; @@ -382,6 +383,19 @@ public void removeListener() throws Exception { assertThat(messages.poll(10, TimeUnit.MILLISECONDS)).isNull(); } + @Test + public void pingNotAllowedInSubscriptionState() throws Exception { + + pubsub.subscribe(channel).get(); + + assertThatThrownBy(() -> pubsub.ping().get()).isInstanceOf(RedisException.class).hasMessageContaining("not allowed"); + pubsub.unsubscribe(channel); + + Wait.untilTrue(() -> channels.size() == 2).waitOrTimeout(); + + assertThat(pubsub.ping().get()).isEqualTo("PONG"); + } + // RedisPubSubListener implementation @Override