From bfd0f18d3e23e3cd226f2ddf44704f75e5034d8f Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 7 Nov 2018 15:09:33 +0100 Subject: [PATCH] Add equality check for subscribed pattern and channel names #911 Lettuce now stores pattern and channel names to which a connection is subscribed to with a comparison wrapper. This allows to compute hashCode and check for equality for built-in channel name types that do not support equals/hashCode for their actual content, in particular byte arrays. Using byte[] for channel names prevented a proper equality check regarding the binary content and caused duplicates in the channel list. With every subscription, channel names were added in a quadraric amount at excessive memory cost. --- .../lettuce/core/pubsub/PubSubEndpoint.java | 90 ++++++++++++++--- .../StatefulRedisPubSubConnectionImpl.java | 4 +- .../java/io/lettuce/core/ByteBufferCodec.java | 52 ++++++++++ .../core/CustomCodecIntegrationTests.java | 28 ------ .../core/pubsub/PubSubEndpointUnitTests.java | 98 +++++++++++++++++++ 5 files changed, 229 insertions(+), 43 deletions(-) create mode 100644 src/test/java/io/lettuce/core/ByteBufferCodec.java create mode 100644 src/test/java/io/lettuce/core/pubsub/PubSubEndpointUnitTests.java diff --git a/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java b/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java index 398f1e77c6..923fe51724 100644 --- a/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java +++ b/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java @@ -15,10 +15,7 @@ */ package io.lettuce.core.pubsub; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.util.*; import java.util.concurrent.CopyOnWriteArrayList; import io.lettuce.core.ClientOptions; @@ -38,8 +35,8 @@ public class PubSubEndpoint extends DefaultEndpoint { private static final Set ALLOWED_COMMANDS_SUBSCRIBED; private static final Set SUBSCRIBE_COMMANDS; private final List> listeners = new CopyOnWriteArrayList<>(); - private final Set channels; - private final Set patterns; + private final Set> channels; + private final Set> patterns; private volatile boolean subscribeWritten = false; static { @@ -94,12 +91,20 @@ protected List> getListeners() { return listeners; } + public boolean hasChannelSubscriptions() { + return !channels.isEmpty(); + } + public Set getChannels() { - return channels; + return unwrap(this.channels); + } + + public boolean hasPatternSubscriptions() { + return !patterns.isEmpty(); } public Set getPatterns() { - return patterns; + return unwrap(this.patterns); } @Override @@ -151,7 +156,7 @@ private static void validateCommandAllowed(RedisCommand command) { } private boolean isSubscribed() { - return subscribeWritten && (!channels.isEmpty() || !patterns.isEmpty()); + return subscribeWritten && (hasChannelSubscriptions() || hasPatternSubscriptions()); } public void notifyMessage(PubSubOutput output) { @@ -197,19 +202,78 @@ private void updateInternalState(PubSubOutput output) { // update internal state switch (output.type()) { case psubscribe: - patterns.add(output.pattern()); + patterns.add(new Wrapper<>(output.pattern())); break; case punsubscribe: - patterns.remove(output.pattern()); + patterns.remove(new Wrapper<>(output.pattern())); break; case subscribe: - channels.add(output.channel()); + channels.add(new Wrapper<>(output.channel())); break; case unsubscribe: - channels.remove(output.channel()); + channels.remove(new Wrapper<>(output.channel())); break; default: break; } } + + private Set unwrap(Set> wrapped) { + + Set result = new LinkedHashSet<>(wrapped.size()); + + for (Wrapper channel : wrapped) { + result.add(channel.name); + } + + return result; + } + + /** + * Comparison/equality wrapper with specific {@code byte[]} equals and hashCode implementations. + * + * @param + */ + static class Wrapper { + + protected final K name; + + public Wrapper(K name) { + this.name = name; + } + + @Override + public int hashCode() { + + if (name instanceof byte[]) { + return Arrays.hashCode((byte[]) name); + } + return name.hashCode(); + } + + @Override + public boolean equals(Object obj) { + + if (!(obj instanceof Wrapper)) { + return false; + } + + Wrapper that = (Wrapper) obj; + + if (name instanceof byte[] && that.name instanceof byte[]) { + return Arrays.equals((byte[]) name, (byte[]) that.name); + } + + return name.equals(that.name); + } + + @Override + public String toString() { + final StringBuffer sb = new StringBuffer(); + sb.append(getClass().getSimpleName()); + sb.append(" [name=").append(name); + sb.append(']'); + return sb.toString(); + } + } } diff --git a/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java b/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java index fd9d0d1b12..31eeecdfbf 100644 --- a/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java +++ b/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java @@ -120,11 +120,11 @@ protected List> resubscribe() { List> result = new ArrayList<>(); - if (!endpoint.getChannels().isEmpty()) { + if (endpoint.hasChannelSubscriptions()) { result.add(async().subscribe(toArray(endpoint.getChannels()))); } - if (!endpoint.getPatterns().isEmpty()) { + if (endpoint.hasPatternSubscriptions()) { result.add(async().psubscribe(toArray(endpoint.getPatterns()))); } diff --git a/src/test/java/io/lettuce/core/ByteBufferCodec.java b/src/test/java/io/lettuce/core/ByteBufferCodec.java new file mode 100644 index 0000000000..89f704a992 --- /dev/null +++ b/src/test/java/io/lettuce/core/ByteBufferCodec.java @@ -0,0 +1,52 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.lettuce.core; + +import java.nio.ByteBuffer; + +import io.lettuce.core.codec.RedisCodec; + +/** + * @author Mark Paluch + */ +public class ByteBufferCodec implements RedisCodec { + + @Override + public ByteBuffer decodeKey(ByteBuffer bytes) { + + ByteBuffer decoupled = ByteBuffer.allocate(bytes.remaining()); + decoupled.put(bytes); + return (ByteBuffer) decoupled.flip(); + } + + @Override + public ByteBuffer decodeValue(ByteBuffer bytes) { + + ByteBuffer decoupled = ByteBuffer.allocate(bytes.remaining()); + decoupled.put(bytes); + return (ByteBuffer) decoupled.flip(); + } + + @Override + public ByteBuffer encodeKey(ByteBuffer key) { + return key.asReadOnlyBuffer(); + } + + @Override + public ByteBuffer encodeValue(ByteBuffer value) { + return value.asReadOnlyBuffer(); + } +} diff --git a/src/test/java/io/lettuce/core/CustomCodecIntegrationTests.java b/src/test/java/io/lettuce/core/CustomCodecIntegrationTests.java index 8a620b39cc..41cc9660f1 100644 --- a/src/test/java/io/lettuce/core/CustomCodecIntegrationTests.java +++ b/src/test/java/io/lettuce/core/CustomCodecIntegrationTests.java @@ -170,32 +170,4 @@ public ByteBuffer encodeValue(Object value) { } } - class ByteBufferCodec implements RedisCodec { - - @Override - public ByteBuffer decodeKey(ByteBuffer bytes) { - - ByteBuffer decoupled = ByteBuffer.allocate(bytes.remaining()); - decoupled.put(bytes); - return (ByteBuffer) decoupled.flip(); - } - - @Override - public ByteBuffer decodeValue(ByteBuffer bytes) { - - ByteBuffer decoupled = ByteBuffer.allocate(bytes.remaining()); - decoupled.put(bytes); - return (ByteBuffer) decoupled.flip(); - } - - @Override - public ByteBuffer encodeKey(ByteBuffer key) { - return key.asReadOnlyBuffer(); - } - - @Override - public ByteBuffer encodeValue(ByteBuffer value) { - return value.asReadOnlyBuffer(); - } - } } diff --git a/src/test/java/io/lettuce/core/pubsub/PubSubEndpointUnitTests.java b/src/test/java/io/lettuce/core/pubsub/PubSubEndpointUnitTests.java new file mode 100644 index 0000000000..2580dc4c9e --- /dev/null +++ b/src/test/java/io/lettuce/core/pubsub/PubSubEndpointUnitTests.java @@ -0,0 +1,98 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.lettuce.core.pubsub; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.nio.ByteBuffer; + +import org.junit.jupiter.api.Test; + +import io.lettuce.core.ByteBufferCodec; +import io.lettuce.core.ClientOptions; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.codec.RedisCodec; +import io.lettuce.core.codec.StringCodec; +import io.lettuce.test.resource.TestClientResources; + +/** + * Unit tests for {@link PubSubEndpoint}. + * + * @author Mark Paluch + */ +class PubSubEndpointUnitTests { + + @Test + void shouldRetainUniqueChannelNames() { + + PubSubEndpoint sut = new PubSubEndpoint<>(ClientOptions.create(), TestClientResources.get()); + + sut.notifyMessage(createMessage("subscribe", "channel1", StringCodec.UTF8)); + sut.notifyMessage(createMessage("subscribe", "channel1", StringCodec.UTF8)); + sut.notifyMessage(createMessage("subscribe", "channel1", StringCodec.UTF8)); + sut.notifyMessage(createMessage("subscribe", "channel2", StringCodec.UTF8)); + + assertThat(sut.getChannels()).hasSize(2).containsOnly("channel1", "channel2"); + } + + @Test + void shouldRetainUniqueBinaryChannelNames() { + + PubSubEndpoint sut = new PubSubEndpoint<>(ClientOptions.create(), TestClientResources.get()); + + sut.notifyMessage(createMessage("subscribe", "channel1", ByteArrayCodec.INSTANCE)); + sut.notifyMessage(createMessage("subscribe", "channel1", ByteArrayCodec.INSTANCE)); + sut.notifyMessage(createMessage("subscribe", "channel1", ByteArrayCodec.INSTANCE)); + sut.notifyMessage(createMessage("subscribe", "channel2", ByteArrayCodec.INSTANCE)); + + assertThat(sut.getChannels()).hasSize(2); + } + + @Test + void shouldRetainUniqueByteBufferChannelNames() { + + PubSubEndpoint sut = new PubSubEndpoint<>(ClientOptions.create(), TestClientResources.get()); + + sut.notifyMessage(createMessage("subscribe", "channel1", new ByteBufferCodec())); + sut.notifyMessage(createMessage("subscribe", "channel1", new ByteBufferCodec())); + sut.notifyMessage(createMessage("subscribe", "channel1", new ByteBufferCodec())); + sut.notifyMessage(createMessage("subscribe", "channel2", new ByteBufferCodec())); + + assertThat(sut.getChannels()).hasSize(2).containsOnly(ByteBuffer.wrap("channel1".getBytes()), + ByteBuffer.wrap("channel2".getBytes())); + } + + @Test + void addsAndRemovesChannels() { + + PubSubEndpoint sut = new PubSubEndpoint<>(ClientOptions.create(), TestClientResources.get()); + + sut.notifyMessage(createMessage("subscribe", "channel1", ByteArrayCodec.INSTANCE)); + sut.notifyMessage(createMessage("unsubscribe", "channel1", ByteArrayCodec.INSTANCE)); + + assertThat(sut.getChannels()).isEmpty(); + } + + private static PubSubOutput createMessage(String action, String channel, RedisCodec codec) { + + PubSubOutput output = new PubSubOutput<>(codec); + + output.set(ByteBuffer.wrap(action.getBytes())); + output.set(ByteBuffer.wrap(channel.getBytes())); + + return output; + } +}