diff --git a/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java b/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java index 89eaa73a58..2ccbaf6a5d 100644 --- a/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java +++ b/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java @@ -25,6 +25,7 @@ import java.util.concurrent.CopyOnWriteArrayList; import io.lettuce.core.ClientOptions; +import io.lettuce.core.ConnectionState; import io.lettuce.core.RedisException; import io.lettuce.core.protocol.CommandType; import io.lettuce.core.protocol.DefaultEndpoint; @@ -55,6 +56,8 @@ public class PubSubEndpoint extends DefaultEndpoint { private volatile boolean subscribeWritten = false; + private ConnectionState connectionState; + static { ALLOWED_COMMANDS_SUBSCRIBED = new HashSet<>(6, 1); @@ -195,13 +198,24 @@ protected boolean containsViolatingCommands(Collection command) { - return getProtocolVersion() == ProtocolVersion.RESP3 || ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name()); + + ProtocolVersion protocolVersion = connectionState != null ? connectionState.getNegotiatedProtocolVersion() : null; + + if (protocolVersion == null) { + protocolVersion = getProtocolVersion(); + } + + return protocolVersion == ProtocolVersion.RESP3 || ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name()); } public boolean isSubscribed() { return subscribeWritten && (hasChannelSubscriptions() || hasPatternSubscriptions()); } + void setConnectionState(ConnectionState connectionState) { + this.connectionState = connectionState; + } + void notifyMessage(PubSubMessage message) { // drop empty messages diff --git a/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java b/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java index 966bc76de6..5d55e4d682 100644 --- a/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java +++ b/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java @@ -59,8 +59,8 @@ public StatefulRedisPubSubConnectionImpl(PubSubEndpoint endpoint, RedisCha Duration timeout) { super(writer, endpoint, codec, timeout); - this.endpoint = endpoint; + endpoint.setConnectionState(getConnectionState()); } /** diff --git a/src/test/java/io/lettuce/core/pubsub/PubSubCommandResp2Test.java b/src/test/java/io/lettuce/core/pubsub/PubSubCommandResp2Test.java index ab9c5e3132..7a3b171e97 100644 --- a/src/test/java/io/lettuce/core/pubsub/PubSubCommandResp2Test.java +++ b/src/test/java/io/lettuce/core/pubsub/PubSubCommandResp2Test.java @@ -15,11 +15,16 @@ */ package io.lettuce.core.pubsub; +import static org.assertj.core.api.Assertions.*; + import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import io.lettuce.core.ClientOptions; +import io.lettuce.core.RedisException; import io.lettuce.core.protocol.ProtocolVersion; +import io.lettuce.test.TestFutures; +import io.lettuce.test.Wait; /** * Pub/Sub Command tests using RESP2. @@ -28,13 +33,35 @@ */ class PubSubCommandResp2Test extends PubSubCommandTest { + @Override protected ClientOptions getOptions() { return ClientOptions.builder().protocolVersion(ProtocolVersion.RESP2).build(); } - @Override @Test @Disabled("Push messages are not available with RESP2") + @Override void messageAsPushMessage() { } + + @Test + @Disabled("Does not apply with RESP2") + @Override + void echoAllowedInSubscriptionState() { + } + + @Test + void echoNotAllowedInSubscriptionState() { + + TestFutures.awaitOrTimeout(pubsub.subscribe(channel)); + + assertThatThrownBy(() -> TestFutures.getOrTimeout(pubsub.echo("ping"))).isInstanceOf(RedisException.class) + .hasMessageContaining("not allowed"); + pubsub.unsubscribe(channel); + + Wait.untilTrue(() -> channels.size() == 2).waitOrTimeout(); + + assertThat(TestFutures.getOrTimeout(pubsub.echo("ping"))).isEqualTo("ping"); + } + } diff --git a/src/test/java/io/lettuce/core/pubsub/PubSubCommandTest.java b/src/test/java/io/lettuce/core/pubsub/PubSubCommandTest.java index 84c4c8d886..ddc07ed421 100644 --- a/src/test/java/io/lettuce/core/pubsub/PubSubCommandTest.java +++ b/src/test/java/io/lettuce/core/pubsub/PubSubCommandTest.java @@ -32,13 +32,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import io.lettuce.core.AbstractRedisClientTest; -import io.lettuce.core.ClientOptions; -import io.lettuce.core.KillArgs; -import io.lettuce.core.RedisClient; -import io.lettuce.core.RedisException; -import io.lettuce.core.RedisFuture; -import io.lettuce.core.RedisURI; +import io.lettuce.core.*; import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.api.push.PushMessage; import io.lettuce.core.internal.LettuceFactories; @@ -63,17 +57,22 @@ */ class PubSubCommandTest extends AbstractRedisClientTest implements RedisPubSubListener { - private RedisPubSubAsyncCommands pubsub; + RedisPubSubAsyncCommands pubsub; - private BlockingQueue channels; - private BlockingQueue patterns; - private BlockingQueue messages; - private BlockingQueue counts; + BlockingQueue channels; - private String channel = "channel0"; - private String shardChannel = "shard-channel"; + BlockingQueue patterns; + + BlockingQueue messages; + + BlockingQueue counts; + + String channel = "channel0"; + + String shardChannel = "shard-channel"; private String pattern = "channel*"; - private String message = "msg!"; + + String message = "msg!"; @BeforeEach void openPubSubConnection() { @@ -464,6 +463,7 @@ void adapter() throws Exception { final BlockingQueue localCounts = LettuceFactories.newBlockingQueue(); RedisPubSubAdapter adapter = new RedisPubSubAdapter() { + @Override public void subscribed(String channel, long count) { super.subscribed(channel, count); @@ -475,6 +475,7 @@ public void unsubscribed(String channel, long count) { super.unsubscribed(channel, count); localCounts.add(count); } + }; pubsub.getStatefulConnection().addListener(adapter); @@ -507,17 +508,12 @@ void removeListener() throws Exception { } @Test - void pingNotAllowedInSubscriptionState() { + void echoAllowedInSubscriptionState() { TestFutures.awaitOrTimeout(pubsub.subscribe(channel)); - assertThatThrownBy(() -> TestFutures.getOrTimeout(pubsub.echo("ping"))).isInstanceOf(RedisException.class) - .hasMessageContaining("not allowed"); - pubsub.unsubscribe(channel); - - Wait.untilTrue(() -> channels.size() == 2).waitOrTimeout(); - assertThat(TestFutures.getOrTimeout(pubsub.echo("ping"))).isEqualTo("ping"); + pubsub.unsubscribe(channel); } // RedisPubSubListener implementation @@ -558,4 +554,5 @@ public void punsubscribed(String pattern, long count) { patterns.add(pattern); counts.add(count); } + } diff --git a/src/test/java/io/lettuce/core/tracing/SynchronousIntegrationTests.java b/src/test/java/io/lettuce/core/tracing/SynchronousIntegrationTests.java index 76ae16f210..12febc72dd 100644 --- a/src/test/java/io/lettuce/core/tracing/SynchronousIntegrationTests.java +++ b/src/test/java/io/lettuce/core/tracing/SynchronousIntegrationTests.java @@ -28,7 +28,6 @@ import io.lettuce.test.resource.FastShutdown; import io.lettuce.test.settings.TestSettings; import io.micrometer.core.instrument.MeterRegistry; -import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import io.micrometer.observation.ObservationRegistry; import io.micrometer.tracing.exporter.FinishedSpan; import io.micrometer.tracing.test.SampleTestRunner; @@ -83,7 +82,6 @@ public SampleTestRunnerConsumer yourCode() { FastShutdown.shutdown(clientResources); assertThat(tracer.getFinishedSpans()).isNotEmpty(); - System.out.println(((SimpleMeterRegistry) meterRegistry).getMetersAsString()); assertThat(tracer.getFinishedSpans()).isNotEmpty();