Skip to content

Commit

Permalink
Polishing #2594
Browse files Browse the repository at this point in the history
Update tests to reflect new behavior. Use negotiated protocol version and fall back to the configured one of no negotiated version is available.

Original pull request: #2778
  • Loading branch information
mp911de committed Mar 13, 2024
1 parent 73cb832 commit 2c5f59e
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 27 deletions.
16 changes: 15 additions & 1 deletion src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.concurrent.CopyOnWriteArrayList;

import io.lettuce.core.ClientOptions;
import io.lettuce.core.ConnectionState;
import io.lettuce.core.RedisException;
import io.lettuce.core.protocol.CommandType;
import io.lettuce.core.protocol.DefaultEndpoint;
Expand Down Expand Up @@ -55,6 +56,8 @@ public class PubSubEndpoint<K, V> extends DefaultEndpoint {

private volatile boolean subscribeWritten = false;

private ConnectionState connectionState;

static {

ALLOWED_COMMANDS_SUBSCRIBED = new HashSet<>(6, 1);
Expand Down Expand Up @@ -195,13 +198,24 @@ protected boolean containsViolatingCommands(Collection<? extends RedisCommand<?,
}

private boolean isAllowed(RedisCommand<?, ?, ?> command) {
return getProtocolVersion() == ProtocolVersion.RESP3 || ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name());

ProtocolVersion protocolVersion = connectionState != null ? connectionState.getNegotiatedProtocolVersion() : null;

if (protocolVersion == null) {
protocolVersion = getProtocolVersion();
}

return protocolVersion == ProtocolVersion.RESP3 || ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().name());
}

public boolean isSubscribed() {
return subscribeWritten && (hasChannelSubscriptions() || hasPatternSubscriptions());
}

void setConnectionState(ConnectionState connectionState) {
this.connectionState = connectionState;
}

void notifyMessage(PubSubMessage<K, V> message) {

// drop empty messages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ public StatefulRedisPubSubConnectionImpl(PubSubEndpoint<K, V> endpoint, RedisCha
Duration timeout) {

super(writer, endpoint, codec, timeout);

this.endpoint = endpoint;
endpoint.setConnectionState(getConnectionState());
}

/**
Expand Down
29 changes: 28 additions & 1 deletion src/test/java/io/lettuce/core/pubsub/PubSubCommandResp2Test.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@
*/
package io.lettuce.core.pubsub;

import static org.assertj.core.api.Assertions.*;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import io.lettuce.core.ClientOptions;
import io.lettuce.core.RedisException;
import io.lettuce.core.protocol.ProtocolVersion;
import io.lettuce.test.TestFutures;
import io.lettuce.test.Wait;

/**
* Pub/Sub Command tests using RESP2.
Expand All @@ -28,13 +33,35 @@
*/
class PubSubCommandResp2Test extends PubSubCommandTest {

@Override
protected ClientOptions getOptions() {
return ClientOptions.builder().protocolVersion(ProtocolVersion.RESP2).build();
}

@Override
@Test
@Disabled("Push messages are not available with RESP2")
@Override
void messageAsPushMessage() {
}

@Test
@Disabled("Does not apply with RESP2")
@Override
void echoAllowedInSubscriptionState() {
}

@Test
void echoNotAllowedInSubscriptionState() {

TestFutures.awaitOrTimeout(pubsub.subscribe(channel));

assertThatThrownBy(() -> TestFutures.getOrTimeout(pubsub.echo("ping"))).isInstanceOf(RedisException.class)
.hasMessageContaining("not allowed");
pubsub.unsubscribe(channel);

Wait.untilTrue(() -> channels.size() == 2).waitOrTimeout();

assertThat(TestFutures.getOrTimeout(pubsub.echo("ping"))).isEqualTo("ping");
}

}
41 changes: 19 additions & 22 deletions src/test/java/io/lettuce/core/pubsub/PubSubCommandTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import io.lettuce.core.AbstractRedisClientTest;
import io.lettuce.core.ClientOptions;
import io.lettuce.core.KillArgs;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisException;
import io.lettuce.core.RedisFuture;
import io.lettuce.core.RedisURI;
import io.lettuce.core.*;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.api.push.PushMessage;
import io.lettuce.core.internal.LettuceFactories;
Expand All @@ -63,17 +57,22 @@
*/
class PubSubCommandTest extends AbstractRedisClientTest implements RedisPubSubListener<String, String> {

private RedisPubSubAsyncCommands<String, String> pubsub;
RedisPubSubAsyncCommands<String, String> pubsub;

private BlockingQueue<String> channels;
private BlockingQueue<String> patterns;
private BlockingQueue<String> messages;
private BlockingQueue<Long> counts;
BlockingQueue<String> channels;

private String channel = "channel0";
private String shardChannel = "shard-channel";
BlockingQueue<String> patterns;

BlockingQueue<String> messages;

BlockingQueue<Long> counts;

String channel = "channel0";

String shardChannel = "shard-channel";
private String pattern = "channel*";
private String message = "msg!";

String message = "msg!";

@BeforeEach
void openPubSubConnection() {
Expand Down Expand Up @@ -464,6 +463,7 @@ void adapter() throws Exception {
final BlockingQueue<Long> localCounts = LettuceFactories.newBlockingQueue();

RedisPubSubAdapter<String, String> adapter = new RedisPubSubAdapter<String, String>() {

@Override
public void subscribed(String channel, long count) {
super.subscribed(channel, count);
Expand All @@ -475,6 +475,7 @@ public void unsubscribed(String channel, long count) {
super.unsubscribed(channel, count);
localCounts.add(count);
}

};

pubsub.getStatefulConnection().addListener(adapter);
Expand Down Expand Up @@ -507,17 +508,12 @@ void removeListener() throws Exception {
}

@Test
void pingNotAllowedInSubscriptionState() {
void echoAllowedInSubscriptionState() {

TestFutures.awaitOrTimeout(pubsub.subscribe(channel));

assertThatThrownBy(() -> TestFutures.getOrTimeout(pubsub.echo("ping"))).isInstanceOf(RedisException.class)
.hasMessageContaining("not allowed");
pubsub.unsubscribe(channel);

Wait.untilTrue(() -> channels.size() == 2).waitOrTimeout();

assertThat(TestFutures.getOrTimeout(pubsub.echo("ping"))).isEqualTo("ping");
pubsub.unsubscribe(channel);
}

// RedisPubSubListener implementation
Expand Down Expand Up @@ -558,4 +554,5 @@ public void punsubscribed(String pattern, long count) {
patterns.add(pattern);
counts.add(count);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import io.lettuce.test.resource.FastShutdown;
import io.lettuce.test.settings.TestSettings;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.tracing.exporter.FinishedSpan;
import io.micrometer.tracing.test.SampleTestRunner;
Expand Down Expand Up @@ -83,7 +82,6 @@ public SampleTestRunnerConsumer yourCode() {
FastShutdown.shutdown(clientResources);

assertThat(tracer.getFinishedSpans()).isNotEmpty();
System.out.println(((SimpleMeterRegistry) meterRegistry).getMetersAsString());

assertThat(tracer.getFinishedSpans()).isNotEmpty();

Expand Down

0 comments on commit 2c5f59e

Please sign in to comment.