Skip to content

Commit

Permalink
Add equality check for subscribed pattern and channel names #911
Browse files Browse the repository at this point in the history
Lettuce now stores pattern and channel names to which a connection is subscribed to with a comparison wrapper. This allows to compute hashCode and check for equality for built-in channel name types that do not support equals/hashCode for their actual content, in particular byte arrays.

Using byte[] for channel names prevented a proper equality check regarding the binary content and caused duplicates in the channel list. With every subscription, channel names were added in a quadraric amount at excessive memory cost.
  • Loading branch information
mp911de committed Nov 7, 2018
1 parent acb03f4 commit bfd0f18
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 43 deletions.
90 changes: 77 additions & 13 deletions src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
*/
package io.lettuce.core.pubsub;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;

import io.lettuce.core.ClientOptions;
Expand All @@ -38,8 +35,8 @@ public class PubSubEndpoint<K, V> extends DefaultEndpoint {
private static final Set<String> ALLOWED_COMMANDS_SUBSCRIBED;
private static final Set<String> SUBSCRIBE_COMMANDS;
private final List<RedisPubSubListener<K, V>> listeners = new CopyOnWriteArrayList<>();
private final Set<K> channels;
private final Set<K> patterns;
private final Set<Wrapper<K>> channels;
private final Set<Wrapper<K>> patterns;
private volatile boolean subscribeWritten = false;

static {
Expand Down Expand Up @@ -94,12 +91,20 @@ protected List<RedisPubSubListener<K, V>> getListeners() {
return listeners;
}

public boolean hasChannelSubscriptions() {
return !channels.isEmpty();
}

public Set<K> getChannels() {
return channels;
return unwrap(this.channels);
}

public boolean hasPatternSubscriptions() {
return !patterns.isEmpty();
}

public Set<K> getPatterns() {
return patterns;
return unwrap(this.patterns);
}

@Override
Expand Down Expand Up @@ -151,7 +156,7 @@ private static void validateCommandAllowed(RedisCommand<?, ?, ?> command) {
}

private boolean isSubscribed() {
return subscribeWritten && (!channels.isEmpty() || !patterns.isEmpty());
return subscribeWritten && (hasChannelSubscriptions() || hasPatternSubscriptions());
}

public void notifyMessage(PubSubOutput<K, V, V> output) {
Expand Down Expand Up @@ -197,19 +202,78 @@ private void updateInternalState(PubSubOutput<K, V, V> output) {
// update internal state
switch (output.type()) {
case psubscribe:
patterns.add(output.pattern());
patterns.add(new Wrapper<>(output.pattern()));
break;
case punsubscribe:
patterns.remove(output.pattern());
patterns.remove(new Wrapper<>(output.pattern()));
break;
case subscribe:
channels.add(output.channel());
channels.add(new Wrapper<>(output.channel()));
break;
case unsubscribe:
channels.remove(output.channel());
channels.remove(new Wrapper<>(output.channel()));
break;
default:
break;
}
}

private Set<K> unwrap(Set<Wrapper<K>> wrapped) {

Set<K> result = new LinkedHashSet<>(wrapped.size());

for (Wrapper<K> channel : wrapped) {
result.add(channel.name);
}

return result;
}

/**
* Comparison/equality wrapper with specific {@code byte[]} equals and hashCode implementations.
*
* @param <K>
*/
static class Wrapper<K> {

protected final K name;

public Wrapper(K name) {
this.name = name;
}

@Override
public int hashCode() {

if (name instanceof byte[]) {
return Arrays.hashCode((byte[]) name);
}
return name.hashCode();
}

@Override
public boolean equals(Object obj) {

if (!(obj instanceof Wrapper)) {
return false;
}

Wrapper<K> that = (Wrapper<K>) obj;

if (name instanceof byte[] && that.name instanceof byte[]) {
return Arrays.equals((byte[]) name, (byte[]) that.name);
}

return name.equals(that.name);
}

@Override
public String toString() {
final StringBuffer sb = new StringBuffer();
sb.append(getClass().getSimpleName());
sb.append(" [name=").append(name);
sb.append(']');
return sb.toString();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ protected List<RedisFuture<Void>> resubscribe() {

List<RedisFuture<Void>> result = new ArrayList<>();

if (!endpoint.getChannels().isEmpty()) {
if (endpoint.hasChannelSubscriptions()) {
result.add(async().subscribe(toArray(endpoint.getChannels())));
}

if (!endpoint.getPatterns().isEmpty()) {
if (endpoint.hasPatternSubscriptions()) {
result.add(async().psubscribe(toArray(endpoint.getPatterns())));
}

Expand Down
52 changes: 52 additions & 0 deletions src/test/java/io/lettuce/core/ByteBufferCodec.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.lettuce.core;

import java.nio.ByteBuffer;

import io.lettuce.core.codec.RedisCodec;

/**
* @author Mark Paluch
*/
public class ByteBufferCodec implements RedisCodec<ByteBuffer, ByteBuffer> {

@Override
public ByteBuffer decodeKey(ByteBuffer bytes) {

ByteBuffer decoupled = ByteBuffer.allocate(bytes.remaining());
decoupled.put(bytes);
return (ByteBuffer) decoupled.flip();
}

@Override
public ByteBuffer decodeValue(ByteBuffer bytes) {

ByteBuffer decoupled = ByteBuffer.allocate(bytes.remaining());
decoupled.put(bytes);
return (ByteBuffer) decoupled.flip();
}

@Override
public ByteBuffer encodeKey(ByteBuffer key) {
return key.asReadOnlyBuffer();
}

@Override
public ByteBuffer encodeValue(ByteBuffer value) {
return value.asReadOnlyBuffer();
}
}
28 changes: 0 additions & 28 deletions src/test/java/io/lettuce/core/CustomCodecIntegrationTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,32 +170,4 @@ public ByteBuffer encodeValue(Object value) {
}
}

class ByteBufferCodec implements RedisCodec<ByteBuffer, ByteBuffer> {

@Override
public ByteBuffer decodeKey(ByteBuffer bytes) {

ByteBuffer decoupled = ByteBuffer.allocate(bytes.remaining());
decoupled.put(bytes);
return (ByteBuffer) decoupled.flip();
}

@Override
public ByteBuffer decodeValue(ByteBuffer bytes) {

ByteBuffer decoupled = ByteBuffer.allocate(bytes.remaining());
decoupled.put(bytes);
return (ByteBuffer) decoupled.flip();
}

@Override
public ByteBuffer encodeKey(ByteBuffer key) {
return key.asReadOnlyBuffer();
}

@Override
public ByteBuffer encodeValue(ByteBuffer value) {
return value.asReadOnlyBuffer();
}
}
}
98 changes: 98 additions & 0 deletions src/test/java/io/lettuce/core/pubsub/PubSubEndpointUnitTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.lettuce.core.pubsub;

import static org.assertj.core.api.Assertions.assertThat;

import java.nio.ByteBuffer;

import org.junit.jupiter.api.Test;

import io.lettuce.core.ByteBufferCodec;
import io.lettuce.core.ClientOptions;
import io.lettuce.core.codec.ByteArrayCodec;
import io.lettuce.core.codec.RedisCodec;
import io.lettuce.core.codec.StringCodec;
import io.lettuce.test.resource.TestClientResources;

/**
* Unit tests for {@link PubSubEndpoint}.
*
* @author Mark Paluch
*/
class PubSubEndpointUnitTests {

@Test
void shouldRetainUniqueChannelNames() {

PubSubEndpoint<String, String> sut = new PubSubEndpoint<>(ClientOptions.create(), TestClientResources.get());

sut.notifyMessage(createMessage("subscribe", "channel1", StringCodec.UTF8));
sut.notifyMessage(createMessage("subscribe", "channel1", StringCodec.UTF8));
sut.notifyMessage(createMessage("subscribe", "channel1", StringCodec.UTF8));
sut.notifyMessage(createMessage("subscribe", "channel2", StringCodec.UTF8));

assertThat(sut.getChannels()).hasSize(2).containsOnly("channel1", "channel2");
}

@Test
void shouldRetainUniqueBinaryChannelNames() {

PubSubEndpoint<byte[], byte[]> sut = new PubSubEndpoint<>(ClientOptions.create(), TestClientResources.get());

sut.notifyMessage(createMessage("subscribe", "channel1", ByteArrayCodec.INSTANCE));
sut.notifyMessage(createMessage("subscribe", "channel1", ByteArrayCodec.INSTANCE));
sut.notifyMessage(createMessage("subscribe", "channel1", ByteArrayCodec.INSTANCE));
sut.notifyMessage(createMessage("subscribe", "channel2", ByteArrayCodec.INSTANCE));

assertThat(sut.getChannels()).hasSize(2);
}

@Test
void shouldRetainUniqueByteBufferChannelNames() {

PubSubEndpoint<ByteBuffer, ByteBuffer> sut = new PubSubEndpoint<>(ClientOptions.create(), TestClientResources.get());

sut.notifyMessage(createMessage("subscribe", "channel1", new ByteBufferCodec()));
sut.notifyMessage(createMessage("subscribe", "channel1", new ByteBufferCodec()));
sut.notifyMessage(createMessage("subscribe", "channel1", new ByteBufferCodec()));
sut.notifyMessage(createMessage("subscribe", "channel2", new ByteBufferCodec()));

assertThat(sut.getChannels()).hasSize(2).containsOnly(ByteBuffer.wrap("channel1".getBytes()),
ByteBuffer.wrap("channel2".getBytes()));
}

@Test
void addsAndRemovesChannels() {

PubSubEndpoint<byte[], byte[]> sut = new PubSubEndpoint<>(ClientOptions.create(), TestClientResources.get());

sut.notifyMessage(createMessage("subscribe", "channel1", ByteArrayCodec.INSTANCE));
sut.notifyMessage(createMessage("unsubscribe", "channel1", ByteArrayCodec.INSTANCE));

assertThat(sut.getChannels()).isEmpty();
}

private static <K, V> PubSubOutput<K, V, V> createMessage(String action, String channel, RedisCodec<K, V> codec) {

PubSubOutput<K, V, V> output = new PubSubOutput<>(codec);

output.set(ByteBuffer.wrap(action.getBytes()));
output.set(ByteBuffer.wrap(channel.getBytes()));

return output;
}
}

0 comments on commit bfd0f18

Please sign in to comment.