From da9c463210a5bf385505b355b27785b22e5684af Mon Sep 17 00:00:00 2001 From: M Sazzadul Hoque <7600764+sazzad16@users.noreply.github.com> Date: Thu, 28 Dec 2023 18:52:36 +0600 Subject: [PATCH] Initial support for client-side caching (#3658) --- .../redis/clients/jedis/ClientSideCache.java | 58 +++++++++++++++ .../java/redis/clients/jedis/Connection.java | 21 +++++- .../clients/jedis/JedisClientSideCache.java | 45 ++++++++++++ .../java/redis/clients/jedis/Protocol.java | 27 ++++++- .../clients/jedis/util/RedisInputStream.java | 19 +++++ .../jedis/JedisClientSideCacheTest.java | 70 +++++++++++++++++++ 6 files changed, 238 insertions(+), 2 deletions(-) create mode 100644 src/main/java/redis/clients/jedis/ClientSideCache.java create mode 100644 src/main/java/redis/clients/jedis/JedisClientSideCache.java create mode 100644 src/test/java/redis/clients/jedis/JedisClientSideCacheTest.java diff --git a/src/main/java/redis/clients/jedis/ClientSideCache.java b/src/main/java/redis/clients/jedis/ClientSideCache.java new file mode 100644 index 0000000000..5dd31b17e9 --- /dev/null +++ b/src/main/java/redis/clients/jedis/ClientSideCache.java @@ -0,0 +1,58 @@ +package redis.clients.jedis; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import redis.clients.jedis.exceptions.JedisException; +import redis.clients.jedis.util.SafeEncoder; + +public class ClientSideCache { + + private final Map cache = new HashMap<>(); + + protected ClientSideCache() { + } + + protected void invalidateKeys(List list) { + if (list == null) { + cache.clear(); + return; + } + + list.forEach(this::invalidateKey); + } + + private void invalidateKey(Object key) { + if (key instanceof byte[]) { + cache.remove(convertKey((byte[]) key)); + } else { + throw new JedisException("" + key.getClass().getSimpleName() + " is not supported. Value: " + String.valueOf(key)); + } + } + + protected void setKey(Object key, Object value) { + cache.put(getMapKey(key), value); + } + + protected T getValue(Object key) { + return (T) getMapValue(key); + } + + private Object getMapValue(Object key) { + return cache.get(getMapKey(key)); + } + + private ByteBuffer getMapKey(Object key) { + if (key instanceof byte[]) { + return convertKey((byte[]) key); + } else { + return convertKey(SafeEncoder.encode(String.valueOf(key))); + } + } + + private static ByteBuffer convertKey(byte[] b) { + return ByteBuffer.wrap(b); + } +} diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index 50243e20d7..58bc941706 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -34,6 +34,7 @@ public class Connection implements Closeable { private Socket socket; private RedisOutputStream outputStream; private RedisInputStream inputStream; + private ClientSideCache clientSideCache; private int soTimeout = 0; private int infiniteSoTimeout = 0; private boolean broken = false; @@ -121,6 +122,10 @@ public void rollbackTimeout() { } } + final void setClientSideCache(ClientSideCache clientSideCache) { + this.clientSideCache = clientSideCache; + } + public Object executeCommand(final ProtocolCommand cmd) { return executeCommand(new CommandArguments(cmd)); } @@ -347,9 +352,10 @@ protected Object readProtocolWithCheckingBroken() { } try { + Protocol.readPushes(inputStream, clientSideCache); return Protocol.read(inputStream); // Object read = Protocol.read(inputStream); -// System.out.println(SafeEncoder.encodeObject(read)); +// System.out.println("REPLY: " + SafeEncoder.encodeObject(read)); // return read; } catch (JedisConnectionException exc) { broken = true; @@ -370,6 +376,19 @@ public List getMany(final int count) { return responses; } + protected void readPushesWithCheckingBroken() { + if (broken) { + throw new JedisConnectionException("Attempting to read pushes from a broken connection"); + } + + try { + Protocol.readPushes(inputStream, clientSideCache); + } catch (JedisConnectionException exc) { + broken = true; + throw exc; + } + } + /** * Check if the client name libname, libver, characters are legal * @param info the name diff --git a/src/main/java/redis/clients/jedis/JedisClientSideCache.java b/src/main/java/redis/clients/jedis/JedisClientSideCache.java new file mode 100644 index 0000000000..73f2a71124 --- /dev/null +++ b/src/main/java/redis/clients/jedis/JedisClientSideCache.java @@ -0,0 +1,45 @@ +package redis.clients.jedis; + +import redis.clients.jedis.exceptions.JedisException; + +public class JedisClientSideCache extends Jedis { + + private final ClientSideCache cache; + + public JedisClientSideCache(final HostAndPort hostPort, final JedisClientConfig config) { + this(hostPort, config, new ClientSideCache()); + } + + public JedisClientSideCache(final HostAndPort hostPort, final JedisClientConfig config, + ClientSideCache cache) { + super(hostPort, config); + if (config.getRedisProtocol() != RedisProtocol.RESP3) { + throw new JedisException("Client side caching is only supported with RESP3."); + } + + this.cache = cache; + this.connection.setClientSideCache(cache); + clientTrackingOn(); + } + + private void clientTrackingOn() { + String reply = connection.executeCommand(new CommandObject<>( + new CommandArguments(Protocol.Command.CLIENT).add("TRACKING").add("ON").add("BCAST"), + BuilderFactory.STRING)); + if (!"OK".equals(reply)) { + throw new JedisException("Could not enable client tracking. Reply: " + reply); + } + } + + @Override + public String get(String key) { + connection.readPushesWithCheckingBroken(); + String cachedValue = cache.getValue(key); + if (cachedValue != null) return cachedValue; + + String value = super.get(key); + if (value != null) cache.setKey(key, value); + return value; + } + +} diff --git a/src/main/java/redis/clients/jedis/Protocol.java b/src/main/java/redis/clients/jedis/Protocol.java index 234b73bda9..0e276c0d93 100644 --- a/src/main/java/redis/clients/jedis/Protocol.java +++ b/src/main/java/redis/clients/jedis/Protocol.java @@ -4,8 +4,10 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Locale; +import java.util.Objects; import redis.clients.jedis.exceptions.*; import redis.clients.jedis.args.Rawable; @@ -57,6 +59,8 @@ public final class Protocol { private static final String WRONGPASS_PREFIX = "WRONGPASS"; private static final String NOPERM_PREFIX = "NOPERM"; + private static final byte[] INVALIDATE_BYTES = SafeEncoder.encode("invalidate"); + private Protocol() { throw new InstantiationError("Must not instantiate this class"); } @@ -133,7 +137,7 @@ private static String[] parseTargetHostAndSlot(String clusterRedirectResponse) { private static Object process(final RedisInputStream is) { final byte b = is.readByte(); - //System.out.println((char) b); + //System.out.println("BYTE: " + (char) b); switch (b) { case PLUS_BYTE: return is.readLineBytes(); @@ -167,6 +171,15 @@ private static Object process(final RedisInputStream is) { } } + private static void processPush(final RedisInputStream is, ClientSideCache cache) { + List list = processMultiBulkReply(is); + //System.out.println("PUSH: " + SafeEncoder.encodeObject(list)); + if (list.size() == 2 && list.get(0) instanceof byte[] + && Arrays.equals(INVALIDATE_BYTES, (byte[]) list.get(0))) { + cache.invalidateKeys((List) list.get(1)); + } + } + private static byte[] processBulkReply(final RedisInputStream is) { final int len = is.readIntCrLf(); if (len == -1) { @@ -193,11 +206,13 @@ private static byte[] processBulkReply(final RedisInputStream is) { private static List processMultiBulkReply(final RedisInputStream is) { // private static List processMultiBulkReply(final int num, final RedisInputStream is) { final int num = is.readIntCrLf(); + //System.out.println("MULTI BULK: " + num); if (num == -1) return null; final List ret = new ArrayList<>(num); for (int i = 0; i < num; i++) { try { ret.add(process(is)); + //System.out.println("MULTI >> " + (i+1) + ": " + SafeEncoder.encodeObject(ret.get(i))); } catch (JedisDataException e) { ret.add(e); } @@ -221,6 +236,16 @@ public static Object read(final RedisInputStream is) { return process(is); } + static void readPushes(final RedisInputStream is, final ClientSideCache cache) { + if (cache != null) { + //System.out.println("PEEK: " + is.peekByte()); + while (Objects.equals(GREATER_THAN_BYTE, is.peekByte())) { + is.readByte(); + processPush(is, cache); + } + } + } + public static final byte[] toByteArray(final boolean value) { return value ? BYTES_TRUE : BYTES_FALSE; } diff --git a/src/main/java/redis/clients/jedis/util/RedisInputStream.java b/src/main/java/redis/clients/jedis/util/RedisInputStream.java index a0dad9d437..094ec762d8 100644 --- a/src/main/java/redis/clients/jedis/util/RedisInputStream.java +++ b/src/main/java/redis/clients/jedis/util/RedisInputStream.java @@ -43,6 +43,11 @@ public RedisInputStream(InputStream in) { this(in, INPUT_BUFFER_SIZE); } + public Byte peekByte() { + ensureFillSafe(); + return buf[count]; + } + public byte readByte() throws JedisConnectionException { ensureFill(); return buf[count++]; @@ -252,4 +257,18 @@ private void ensureFill() throws JedisConnectionException { } } } + + private void ensureFillSafe() { + if (count >= limit) { + try { + limit = in.read(buf); + count = 0; + if (limit == -1) { + throw new JedisConnectionException("Unexpected end of stream."); + } + } catch (IOException e) { + // do nothing + } + } + } } diff --git a/src/test/java/redis/clients/jedis/JedisClientSideCacheTest.java b/src/test/java/redis/clients/jedis/JedisClientSideCacheTest.java new file mode 100644 index 0000000000..2375fa5153 --- /dev/null +++ b/src/test/java/redis/clients/jedis/JedisClientSideCacheTest.java @@ -0,0 +1,70 @@ +package redis.clients.jedis; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.InOrder; +import org.mockito.Mockito; + +public class JedisClientSideCacheTest { + + protected static final HostAndPort hnp = HostAndPorts.getRedisServers().get(1); + + protected Jedis jedis; + + @Before + public void setUp() throws Exception { + jedis = new Jedis(hnp, DefaultJedisClientConfig.builder().timeoutMillis(500).password("foobared").build()); + jedis.flushAll(); + } + + @After + public void tearDown() throws Exception { + jedis.close(); + } + + private static final JedisClientConfig configForCache = DefaultJedisClientConfig.builder() + .resp3().socketTimeoutMillis(20).password("foobared").build(); + + @Test + public void simple() { + try (JedisClientSideCache jCache = new JedisClientSideCache(hnp, configForCache)) { + jedis.set("foo", "bar"); + assertEquals("bar", jCache.get("foo")); + jedis.del("foo"); + assertNull(jCache.get("foo")); + } + } + + @Test + public void simpleMock() { + ClientSideCache cache = Mockito.mock(ClientSideCache.class); + try (JedisClientSideCache jCache = new JedisClientSideCache(hnp, configForCache, cache)) { + jedis.set("foo", "bar"); + assertEquals("bar", jCache.get("foo")); + jedis.del("foo"); + assertNull(jCache.get("foo")); + } + + InOrder inOrder = Mockito.inOrder(cache); + inOrder.verify(cache).invalidateKeys(Mockito.notNull()); + inOrder.verify(cache).getValue("foo"); + inOrder.verify(cache).setKey("foo", "bar"); + inOrder.verify(cache).invalidateKeys(Mockito.notNull()); + inOrder.verify(cache).getValue("foo"); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void flushall() { + try (JedisClientSideCache jCache = new JedisClientSideCache(hnp, configForCache)) { + jedis.set("foo", "bar"); + assertEquals("bar", jCache.get("foo")); + jedis.flushAll(); + assertNull(jCache.get("foo")); + } + } +}