diff --git a/src/main/java/com/lambdaworks/redis/RedisAsyncConnectionImpl.java b/src/main/java/com/lambdaworks/redis/RedisAsyncConnectionImpl.java index 9e42057d28..fd35433c3a 100644 --- a/src/main/java/com/lambdaworks/redis/RedisAsyncConnectionImpl.java +++ b/src/main/java/com/lambdaworks/redis/RedisAsyncConnectionImpl.java @@ -2,7 +2,7 @@ package com.lambdaworks.redis; -import static com.lambdaworks.redis.protocol.CommandType.*; +import static com.lambdaworks.redis.protocol.CommandType.EXEC; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -14,8 +14,18 @@ import com.lambdaworks.codec.Base16; import com.lambdaworks.redis.codec.RedisCodec; -import com.lambdaworks.redis.output.*; -import com.lambdaworks.redis.protocol.*; +import com.lambdaworks.redis.output.KeyStreamingChannel; +import com.lambdaworks.redis.output.KeyValueStreamingChannel; +import com.lambdaworks.redis.output.MultiOutput; +import com.lambdaworks.redis.output.ScoredValueStreamingChannel; +import com.lambdaworks.redis.output.ValueStreamingChannel; +import com.lambdaworks.redis.protocol.Command; +import com.lambdaworks.redis.protocol.CommandArgs; +import com.lambdaworks.redis.protocol.CommandOutput; +import com.lambdaworks.redis.protocol.CommandType; +import com.lambdaworks.redis.protocol.ConnectionWatchdog; +import com.lambdaworks.redis.protocol.RedisCommand; +import com.lambdaworks.redis.protocol.SetArgs; import io.netty.channel.ChannelHandler; /** @@ -245,6 +255,10 @@ public RedisFuture del(K... keys) { return dispatch(commandBuilder.del(keys)); } + public RedisFuture del(Iterable keys) { + return dispatch(commandBuilder.del(keys)); + } + @Override public RedisFuture discard() { if (multi != null) { @@ -539,6 +553,10 @@ public RedisFuture> mget(K... keys) { return dispatch(commandBuilder.mget(keys)); } + public RedisFuture> mget(Iterable keys) { + return dispatch(commandBuilder.mget(keys)); + } + @Override public RedisFuture mget(ValueStreamingChannel channel, K... keys) { return dispatch(commandBuilder.mget(channel, keys)); @@ -1579,8 +1597,7 @@ protected RedisCommand dispatch(CommandType type, CommandOutput RedisCommand dispatch(CommandType type, CommandOutput output, - CommandArgs args) { + protected RedisCommand dispatch(CommandType type, CommandOutput output, CommandArgs args) { Command cmd = new Command(type, output, args, multi != null); return dispatch(cmd); } diff --git a/src/main/java/com/lambdaworks/redis/RedisCommandBuilder.java b/src/main/java/com/lambdaworks/redis/RedisCommandBuilder.java index 6271472133..4c8461cc29 100644 --- a/src/main/java/com/lambdaworks/redis/RedisCommandBuilder.java +++ b/src/main/java/com/lambdaworks/redis/RedisCommandBuilder.java @@ -210,6 +210,11 @@ public Command del(K... keys) { return createCommand(DEL, new IntegerOutput(codec), args); } + public Command del(Iterable keys) { + CommandArgs args = new CommandArgs(codec).addKeys(keys); + return createCommand(DEL, new IntegerOutput(codec), args); + } + public Command discard() { return createCommand(DISCARD, new StatusOutput(codec)); } @@ -468,6 +473,11 @@ public Command> mget(K... keys) { return createCommand(MGET, new ValueListOutput(codec), args); } + public Command> mget(Iterable keys) { + CommandArgs args = new CommandArgs(codec).addKeys(keys); + return createCommand(MGET, new ValueListOutput(codec), args); + } + public Command mget(ValueStreamingChannel channel, K... keys) { CommandArgs args = new CommandArgs(codec).addKeys(keys); return createCommand(MGET, new ValueStreamingOutput(codec, channel), args); @@ -524,6 +534,7 @@ public Command pexpireat(K key, long timestamp) { public Command ping() { return createCommand(PING, new StatusOutput(codec)); } + public Command readOnly() { return createCommand(READONLY, new StatusOutput(codec)); } @@ -532,7 +543,6 @@ public Command readWrite() { return createCommand(READWRITE, new StatusOutput(codec)); } - public Command pttl(K key) { CommandArgs args = new CommandArgs(codec).addKey(key); return createCommand(PTTL, new IntegerOutput(codec), args); diff --git a/src/main/java/com/lambdaworks/redis/cluster/ClusterCompletionStage.java b/src/main/java/com/lambdaworks/redis/cluster/ClusterCompletionStage.java deleted file mode 100644 index 7da6f6855a..0000000000 --- a/src/main/java/com/lambdaworks/redis/cluster/ClusterCompletionStage.java +++ /dev/null @@ -1,29 +0,0 @@ -package com.lambdaworks.redis.cluster; - -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletionStage; - -import com.lambdaworks.redis.cluster.models.partitions.RedisClusterNode; - -/** - * Completes - * - * @author Mark Paluch - */ -public interface ClusterCompletionStage { - - Map> asMap(); - - Collection nodes(); - - Collection> stages(); - - CompletionStage get(RedisClusterNode redisClusterNode); - - CompletionStage any(); - - CompletionStage> all(); - -} diff --git a/src/main/java/com/lambdaworks/redis/cluster/PipelinedRedisFuture.java b/src/main/java/com/lambdaworks/redis/cluster/PipelinedRedisFuture.java new file mode 100644 index 0000000000..38038b6375 --- /dev/null +++ b/src/main/java/com/lambdaworks/redis/cluster/PipelinedRedisFuture.java @@ -0,0 +1,53 @@ +package com.lambdaworks.redis.cluster; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import com.lambdaworks.redis.RedisFuture; + +/** + * @author Mark Paluch + */ +public class PipelinedRedisFuture extends CompletableFuture implements RedisFuture { + + private CountDownLatch latch = new CountDownLatch(1); + + public PipelinedRedisFuture(Map> executions, + Function, V> converter) { + + CompletableFuture.allOf(executions.values().toArray(new CompletableFuture[executions.size()])) + .thenRun(() -> complete(converter.apply(this))).exceptionally(throwable -> { + completeExceptionally(throwable); + return null; + }); + } + + @Override + public boolean complete(V value) { + boolean result = super.complete(value); + latch.countDown(); + return result; + } + + @Override + public boolean completeExceptionally(Throwable ex) { + + boolean value = super.completeExceptionally(ex); + latch.countDown(); + return value; + } + + @Override + public String getError() { + return null; + } + + @Override + public boolean await(long timeout, TimeUnit unit) throws InterruptedException { + return latch.await(timeout, unit); + } + +} diff --git a/src/main/java/com/lambdaworks/redis/cluster/RedisAdvancedClusterConnection.java b/src/main/java/com/lambdaworks/redis/cluster/RedisAdvancedClusterConnection.java index 10aeacc208..ae4a00b83f 100644 --- a/src/main/java/com/lambdaworks/redis/cluster/RedisAdvancedClusterConnection.java +++ b/src/main/java/com/lambdaworks/redis/cluster/RedisAdvancedClusterConnection.java @@ -1,9 +1,11 @@ package com.lambdaworks.redis.cluster; +import java.util.List; +import java.util.Map; import java.util.function.Predicate; import com.lambdaworks.redis.RedisClusterAsyncConnection; -import com.lambdaworks.redis.RedisClusterConnection; +import com.lambdaworks.redis.RedisFuture; import com.lambdaworks.redis.cluster.models.partitions.Partitions; import com.lambdaworks.redis.cluster.models.partitions.RedisClusterNode; @@ -63,4 +65,41 @@ default NodeSelectionAsyncOperations all() { * @return the Partitions/Cluster view. */ Partitions getPartitions(); + + /** + * Delete a key with pipelining. Cross-slot keys will result in multiple calls to the particular cluster nodes. + * + * @param keys the key + * @return RedisFuture<Long> integer-reply The number of keys that were removed. + */ + RedisFuture del(K... keys); + + /** + * Get the values of all the given keys with pipelining. Cross-slot keys will result in multiple calls to the particular + * cluster nodes. + * + * @param keys the key + * @return RedisFuture<List<V>> array-reply list of values at the specified keys. + */ + RedisFuture> mget(K... keys); + + /** + * Set multiple keys to multiple values with pipelining. Cross-slot keys will result in multiple calls to the particular + * cluster nodes. + * + * @param map the null + * @return RedisFuture<String> simple-string-reply always {@code OK} since {@code MSET} can't fail. + */ + RedisFuture mset(Map map); + + /** + * Set multiple keys to multiple values, only if none of the keys exist with pipelining. Cross-slot keys will result in + * multiple calls to the particular cluster nodes. + * + * @param map the null + * @return RedisFuture<Boolean> integer-reply specifically: + * + * {@code 1} if the all the keys were set. {@code 0} if no key was set (at least one key already existed). + */ + RedisFuture msetnx(Map map); } diff --git a/src/main/java/com/lambdaworks/redis/cluster/RedisAdvancedClusterConnectionImpl.java b/src/main/java/com/lambdaworks/redis/cluster/RedisAdvancedClusterConnectionImpl.java index 92cd9a297d..0ba35f1b93 100644 --- a/src/main/java/com/lambdaworks/redis/cluster/RedisAdvancedClusterConnectionImpl.java +++ b/src/main/java/com/lambdaworks/redis/cluster/RedisAdvancedClusterConnectionImpl.java @@ -1,11 +1,20 @@ package com.lambdaworks.redis.cluster; import java.lang.reflect.Proxy; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Predicate; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import com.lambdaworks.redis.RedisAsyncConnectionImpl; import com.lambdaworks.redis.RedisChannelWriter; +import com.lambdaworks.redis.RedisException; +import com.lambdaworks.redis.RedisFuture; import com.lambdaworks.redis.cluster.models.partitions.Partitions; import com.lambdaworks.redis.cluster.models.partitions.RedisClusterNode; import com.lambdaworks.redis.codec.RedisCodec; @@ -47,6 +56,157 @@ public NodeSelectionAsyncOperations nodes(Predicate pred new Class[] { NodeSelectionAsyncOperations.class }, h); } + @Override + public RedisFuture del(K... keys) { + + Map> partitioned = Maps.newHashMap(); + Map> executions = Maps.newHashMap(); + partition(partitioned, Arrays.asList(keys), Maps.newHashMap()); + + if (partitioned.size() < 2) { + return super.del(keys); + } + + for (Map.Entry> entry : partitioned.entrySet()) { + RedisFuture del = super.del(entry.getValue()); + executions.put(entry.getKey(), del); + } + + return new PipelinedRedisFuture<>(executions, objectPipelinedRedisFuture -> { + AtomicLong result = new AtomicLong(); + for (RedisFuture longRedisFuture : executions.values()) { + Long value = execute(() -> longRedisFuture.get()); + + if (value != null) { + result.getAndAdd(value); + } + } + + return result.get(); + }); + } + + protected void partition(Map> partitioned, Iterable keys, Map slots) { + for (K key : keys) { + int slot = SlotHash.getSlot(codec.encodeKey(key)); + slots.put(key, slot); + List list = commandPartition(partitioned, slot); + list.add(key); + } + } + + @Override + public RedisFuture> mget(K... keys) { + + Map> partitioned = Maps.newHashMap(); + Map slots = Maps.newHashMap(); + Map>> executions = Maps.newHashMap(); + partition(partitioned, Arrays.asList(keys), slots); + + if (partitioned.size() < 2) { + return super.mget(keys); + } + + for (Map.Entry> entry : partitioned.entrySet()) { + RedisFuture> mget = super.mget(entry.getValue()); + executions.put(entry.getKey(), mget); + } + + return new PipelinedRedisFuture<>(executions, objectPipelinedRedisFuture -> { + List result = Lists.newArrayList(); + for (K opKey : keys) { + int slot = slots.get(opKey); + + int position = partitioned.get(slot).indexOf(opKey); + RedisFuture> listRedisFuture = executions.get(slot); + result.add(execute(() -> listRedisFuture.get().get(position))); + } + + return result; + }); + } + + @Override + public RedisFuture mset(Map map) { + + Map> partitioned = Maps.newHashMap(); + Map slots = Maps.newHashMap(); + Map> executions = Maps.newHashMap(); + partition(partitioned, map.keySet(), slots); + + if (partitioned.size() < 2) { + return super.mset(map); + } + + for (Map.Entry> entry : partitioned.entrySet()) { + + Map op = Maps.newHashMap(); + entry.getValue().forEach(k -> op.put(k, map.get(k))); + + RedisFuture mset = super.mset(op); + executions.put(entry.getKey(), mset); + } + + return new PipelinedRedisFuture<>(executions, objectPipelinedRedisFuture -> { + for (RedisFuture listRedisFuture : executions.values()) { + return execute(() -> listRedisFuture.get()); + } + + return null; + }); + + } + + @Override + public RedisFuture msetnx(Map map) { + + Map> partitioned = Maps.newHashMap(); + Map slots = Maps.newHashMap(); + Map> executions = Maps.newHashMap(); + partition(partitioned, map.keySet(), slots); + + if (partitioned.size() < 2) { + return super.msetnx(map); + } + + for (Map.Entry> entry : partitioned.entrySet()) { + + Map op = Maps.newHashMap(); + entry.getValue().forEach(k -> op.put(k, map.get(k))); + + RedisFuture msetnx = super.msetnx(op); + executions.put(entry.getKey(), msetnx); + } + + return new PipelinedRedisFuture<>(executions, objectPipelinedRedisFuture -> { + for (RedisFuture listRedisFuture : executions.values()) { + Boolean b = execute(() -> listRedisFuture.get()); + if (b != null && b) { + return true; + } + } + + return false; + }); + + } + + private T execute(Callable function) { + try { + return function.call(); + } catch (Exception e) { + throw new RedisException(e); + } + + } + + private List commandPartition(Map> partitioned, int slot) { + if (!partitioned.containsKey(slot)) { + partitioned.put(slot, Lists.newArrayList()); + } + return partitioned.get(slot); + } + @Override public Partitions getPartitions() { return partitions; diff --git a/src/main/java/com/lambdaworks/redis/protocol/Command.java b/src/main/java/com/lambdaworks/redis/protocol/Command.java index d827ce9b1c..fcca6d84f7 100644 --- a/src/main/java/com/lambdaworks/redis/protocol/Command.java +++ b/src/main/java/com/lambdaworks/redis/protocol/Command.java @@ -66,7 +66,6 @@ public boolean isMulti() { return multi; } - /** * Wait up to the specified time for the command output to become available. * @@ -100,8 +99,8 @@ public CommandOutput getOutput() { */ @Override public void complete() { - latch.countDown(); - if (latch.getCount() == 0) { + + if (latch.getCount() == 1) { if (output == null) { complete(null); } else if (output.hasError()) { @@ -110,6 +109,17 @@ public void complete() { complete(output.get()); } } + latch.countDown(); + } + + @Override + public boolean completeExceptionally(Throwable ex) { + boolean result = false; + if (latch.getCount() == 1) { + result = super.completeExceptionally(ex); + } + latch.countDown(); + return result; } /** diff --git a/src/main/java/com/lambdaworks/redis/protocol/CommandArgs.java b/src/main/java/com/lambdaworks/redis/protocol/CommandArgs.java index e413031a45..aa2104aebb 100644 --- a/src/main/java/com/lambdaworks/redis/protocol/CommandArgs.java +++ b/src/main/java/com/lambdaworks/redis/protocol/CommandArgs.java @@ -2,7 +2,7 @@ package com.lambdaworks.redis.protocol; -import static java.lang.Math.*; +import static java.lang.Math.max; import java.nio.BufferOverflowException; import java.nio.ByteBuffer; @@ -47,6 +47,13 @@ public CommandArgs addKey(K key) { return write(codec.encodeKey(key)); } + public CommandArgs addKeys(Iterable keys) { + for (K key : keys) { + addKey(key); + } + return this; + } + public CommandArgs addKeys(K... keys) { for (K key : keys) { addKey(key); diff --git a/src/test/java/com/lambdaworks/SslTest.java b/src/test/java/com/lambdaworks/SslTest.java index 8ec6c20a0f..fe1ce51969 100644 --- a/src/test/java/com/lambdaworks/SslTest.java +++ b/src/test/java/com/lambdaworks/SslTest.java @@ -43,7 +43,6 @@ public void regularSsl() throws Exception { RedisConnection connection = redisClient.connect(redisUri); connection.set("key", "value"); assertThat(connection.get("key")).isEqualTo("value"); - connection.close(); } diff --git a/src/test/java/com/lambdaworks/redis/SentinelFailoverTest.java b/src/test/java/com/lambdaworks/redis/SentinelFailoverTest.java index 6d087daf1f..9a1a6e64e0 100644 --- a/src/test/java/com/lambdaworks/redis/SentinelFailoverTest.java +++ b/src/test/java/com/lambdaworks/redis/SentinelFailoverTest.java @@ -1,13 +1,19 @@ package com.lambdaworks.redis; -import static com.google.code.tempusfugit.temporal.Duration.*; -import static com.google.code.tempusfugit.temporal.Timeout.*; +import static com.google.code.tempusfugit.temporal.Duration.seconds; +import static com.google.code.tempusfugit.temporal.Timeout.timeout; import static com.lambdaworks.redis.TestSettings.port; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; -import org.junit.*; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; import com.google.code.tempusfugit.temporal.Condition; import com.google.code.tempusfugit.temporal.WaitFor; @@ -49,24 +55,30 @@ public void closeConnection() throws Exception { @Test public void connectToRedisUsingSentinel() throws Exception { - WaitFor.waitOrTimeout(new Condition() { - @Override - public boolean isSatisfied() { - return sentinelRule.hasConnectedSlaves(MASTER_WITH_SLAVE_ID); - } - }, timeout(seconds(20))); + waitForAvailableSlave(); RedisConnection connect = sentinelClient.connect(); assertThat(connect.ping()).isEqualToIgnoringCase("PONG"); connect.close(); this.sentinel.failover(MASTER_WITH_SLAVE_ID).get(); + waitForAvailableSlave(); RedisConnection connect2 = sentinelClient.connect(); assertThat(connect2.ping()).isEqualToIgnoringCase("PONG"); connect2.close(); } + protected void waitForAvailableSlave() throws InterruptedException, TimeoutException { + WaitFor.waitOrTimeout(new Condition() { + @Override + public boolean isSatisfied() { + return sentinelRule.hasConnectedSlaves(MASTER_WITH_SLAVE_ID) + && sentinel.getMasterAddrByName(MASTER_WITH_SLAVE_ID) != null; + } + }, timeout(seconds(20))); + } + protected static RedisClient getRedisSentinelClient() { return new RedisClient(RedisURI.Builder.sentinel(host, 26380, MASTER_WITH_SLAVE_ID).build()); } diff --git a/src/test/java/com/lambdaworks/redis/cluster/AdvancedClusterClientTest.java b/src/test/java/com/lambdaworks/redis/cluster/AdvancedClusterClientTest.java index 6807392b88..16609f60db 100644 --- a/src/test/java/com/lambdaworks/redis/cluster/AdvancedClusterClientTest.java +++ b/src/test/java/com/lambdaworks/redis/cluster/AdvancedClusterClientTest.java @@ -20,7 +20,9 @@ import com.google.code.tempusfugit.temporal.ThreadSleep; import com.google.code.tempusfugit.temporal.WaitFor; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import com.lambdaworks.redis.RedisClusterAsyncConnection; +import com.lambdaworks.redis.RedisFuture; import com.lambdaworks.redis.cluster.models.partitions.Partitions; import com.lambdaworks.redis.cluster.models.partitions.RedisClusterNode; @@ -126,8 +128,91 @@ public void testAsynchronicityOfMultiNodeExeccution() throws Exception { assertThat(execution.toCompletableFuture().get()).isEqualTo("OK"); } + CompletableFuture.allOf(eval.futures()).exceptionally(throwable -> { + return null; + }).get(); + for (CompletableFuture future : eval.futures()) { assertThat(future.isDone()).isTrue(); } } + + @Test + public void msetCrossSlot() throws Exception { + + Map mset = Maps.newHashMap(); + for (char c = 'a'; c < 'z'; c++) { + String key = new String(new char[] { c, c, c }); + mset.put(key, "value-" + key); + } + + RedisFuture result = connection.mset(mset); + + assertThat(result.get()).isEqualTo("OK"); + + for (String mykey : mset.keySet()) { + String s1 = connection.get(mykey).get(); + assertThat(s1).isEqualTo("value-" + mykey); + } + } + + @Test + public void msetnxCrossSlot() throws Exception { + + Map mset = Maps.newHashMap(); + for (char c = 'a'; c < 'z'; c++) { + String key = new String(new char[] { c, c, c }); + mset.put(key, "value-" + key); + } + + RedisFuture result = connection.msetnx(mset); + + assertThat(result.get()).isTrue(); + + for (String mykey : mset.keySet()) { + String s1 = connection.get(mykey).get(); + assertThat(s1).isEqualTo("value-" + mykey); + } + } + + @Test + public void mgetCrossSlot() throws Exception { + + msetCrossSlot(); + List keys = Lists.newArrayList(); + List expectation = Lists.newArrayList(); + for (char c = 'a'; c < 'z'; c++) { + String key = new String(new char[] { c, c, c }); + keys.add(key); + expectation.add("value-" + key); + } + + RedisFuture> result = connection.mget(keys.toArray(new String[keys.size()])); + + assertThat(result.get()).hasSize(keys.size()); + assertThat(result.get()).isEqualTo(expectation); + + } + + @Test + public void delCrossSlot() throws Exception { + + msetCrossSlot(); + List keys = Lists.newArrayList(); + for (char c = 'a'; c < 'z'; c++) { + String key = new String(new char[] { c, c, c }); + keys.add(key); + } + + RedisFuture result = connection.del(keys.toArray(new String[keys.size()])); + + assertThat(result.get()).isEqualTo(25); + + for (String mykey : keys) { + String s1 = connection.get(mykey).get(); + assertThat(s1).isNull(); + } + + } + }