From 4e2317a4efe75400fa86b16eab555034dd32b5cf Mon Sep 17 00:00:00 2001 From: Lari Hotari Date: Fri, 31 May 2024 03:25:52 +0300 Subject: [PATCH] [fix][ml] Fix race conditions in RangeCache (#22789) (cherry picked from commit c39f9f82b425c66c899f818583714c9c98d3e213) (cherry picked from commit 9a99e4586067307848179031a3446fa1e0044683) --- .../bookkeeper/mledger/impl/EntryImpl.java | 7 +- .../bookkeeper/mledger/util/RangeCache.java | 278 +++++++++++++----- .../mledger/util/RangeCacheTest.java | 63 ++-- 3 files changed, 254 insertions(+), 94 deletions(-) diff --git a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java index 6512399173f0a..e53f408ca7563 100644 --- a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java +++ b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java @@ -27,9 +27,10 @@ import org.apache.bookkeeper.client.api.LedgerEntry; import org.apache.bookkeeper.mledger.Entry; import org.apache.bookkeeper.mledger.util.AbstractCASReferenceCounted; +import org.apache.bookkeeper.mledger.util.RangeCache; public final class EntryImpl extends AbstractCASReferenceCounted implements Entry, Comparable, - ReferenceCounted { + RangeCache.ValueWithKeyValidation { private static final Recycler RECYCLER = new Recycler() { @Override @@ -200,4 +201,8 @@ protected void deallocate() { recyclerHandle.recycle(this); } + @Override + public boolean matchesKey(PositionImpl key) { + return key.compareTo(ledgerId, entryId) == 0; + } } diff --git a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java index d34857e5e5177..46d03bea1b5ad 100644 --- a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java +++ b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java @@ 
-19,31 +19,134 @@ package org.apache.bookkeeper.mledger.util; import static com.google.common.base.Preconditions.checkArgument; +import com.google.common.base.Predicate; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.Recycler; +import io.netty.util.Recycler.Handle; import io.netty.util.ReferenceCounted; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.ConcurrentNavigableMap; import java.util.concurrent.ConcurrentSkipListMap; import java.util.concurrent.atomic.AtomicLong; +import org.apache.bookkeeper.mledger.util.RangeCache.ValueWithKeyValidation; import org.apache.commons.lang3.tuple.Pair; /** * Special type of cache where get() and delete() operations can be done over a range of keys. + * The implementation avoids locks and synchronization and relies on ConcurrentSkipListMap for storing the entries. + * Since there are no locks, there is a need to have a way to ensure that a single entry in the cache is removed + * exactly once. Removing an entry multiple times would result in the entries of the cache getting released too + * early, while they could still be in use. * * @param * Cache key. Needs to be Comparable * @param * Cache value */ -public class RangeCache, Value extends ReferenceCounted> { +public class RangeCache, Value extends ValueWithKeyValidation> { + public interface ValueWithKeyValidation extends ReferenceCounted { + boolean matchesKey(T key); + } + // Map from key to nodes inside the linked list - private final ConcurrentNavigableMap entries; + private final ConcurrentNavigableMap> entries; private AtomicLong size; // Total size of values stored in cache private final Weighter weighter; // Weighter object used to extract the size from values private final TimestampExtractor timestampExtractor; // Extract the timestamp associated with a value + /** + * Wrapper around the value to store in Map. 
This is needed to ensure that a specific instance can be removed from + * the map by calling the {@link Map#remove(Object, Object)} method. Certain race conditions could result in the + * wrong value being removed from the map. The instances of this class are recycled to avoid creating new objects. + */ + private static class IdentityWrapper { + private final Handle recyclerHandle; + private static final Recycler RECYCLER = new Recycler() { + @Override + protected IdentityWrapper newObject(Handle recyclerHandle) { + return new IdentityWrapper(recyclerHandle); + } + }; + private K key; + private V value; + + private IdentityWrapper(Handle recyclerHandle) { + this.recyclerHandle = recyclerHandle; + } + + static IdentityWrapper create(K key, V value) { + IdentityWrapper identityWrapper = RECYCLER.get(); + identityWrapper.key = key; + identityWrapper.value = value; + return identityWrapper; + } + + K getKey() { + return key; + } + + V getValue() { + return value; + } + + void recycle() { + value = null; + recyclerHandle.recycle(this); + } + + @Override + public boolean equals(Object o) { + // only match exact identity of the value + return this == o; + } + + @Override + public int hashCode() { + return Objects.hashCode(key); + } + } + + /** + * Mutable object to store the number of entries and the total size removed from the cache. The instances + * are recycled to avoid creating new instances. 
+ */ + private static class RemovalCounters { + private final Handle recyclerHandle; + private static final Recycler RECYCLER = new Recycler() { + @Override + protected RemovalCounters newObject(Handle recyclerHandle) { + return new RemovalCounters(recyclerHandle); + } + }; + int removedEntries; + long removedSize; + private RemovalCounters(Handle recyclerHandle) { + this.recyclerHandle = recyclerHandle; + } + + static RemovalCounters create() { + RemovalCounters results = RECYCLER.get(); + results.removedEntries = 0; + results.removedSize = 0; + return results; + } + + void recycle() { + removedEntries = 0; + removedSize = 0; + recyclerHandle.recycle(this); + } + + public void entryRemoved(long size) { + removedSize += size; + removedEntries++; + } + } + /** * Construct a new RangeLruCache with default Weighter. */ @@ -68,18 +171,23 @@ public RangeCache(Weighter weighter, TimestampExtractor timestampE * Insert. * * @param key - * @param value - * ref counted value with at least 1 ref to pass on the cache + * @param value ref counted value with at least 1 ref to pass on the cache * @return whether the entry was inserted in the cache */ public boolean put(Key key, Value value) { // retain value so that it's not released before we put it in the cache and calculate the weight value.retain(); try { - if (entries.putIfAbsent(key, value) == null) { + if (!value.matchesKey(key)) { + throw new IllegalArgumentException("Value '" + value + "' does not match key '" + key + "'"); + } + IdentityWrapper newWrapper = IdentityWrapper.create(key, value); + if (entries.putIfAbsent(key, newWrapper) == null) { size.addAndGet(weighter.getSize(value)); return true; } else { + // recycle the new wrapper as it was not used + newWrapper.recycle(); return false; } } finally { @@ -91,16 +199,37 @@ public boolean exists(Key key) { return key != null ? entries.containsKey(key) : true; } + /** + * Get the value associated with the key and increment the reference count of it. 
+ * The caller is responsible for releasing the reference. + */ public Value get(Key key) { - Value value = entries.get(key); - if (value == null) { + return getValue(key, entries.get(key)); + } + + private Value getValue(Key key, IdentityWrapper valueWrapper) { + if (valueWrapper == null) { return null; } else { + Value value = valueWrapper.getValue(); + if (value == null || !key.equals(valueWrapper.getKey())) { + // recycled wrapper (value cleared or other key); equals() since get() passes the caller's key instance + return null; + } try { value.retain(); + } catch (IllegalReferenceCountException e) { + // Value was already deallocated + return null; + } + // check that the value matches the key and that there's at least 2 references to it since + // the cache should be holding one reference and a new reference was just added in this method + if (value.refCnt() > 1 && value.matchesKey(key)) { return value; - } catch (Throwable t) { - // Value was already destroyed between get() and retain() + } else { + // Value or IdentityWrapper was recycled and already contains another value + // release the reference added in this method + value.release(); return null; } } @@ -118,12 +247,10 @@ public Collection getRange(Key first, Key last) { List values = new ArrayList(); // Return the values of the entries found in cache - for (Value value : entries.subMap(first, true, last, true).values()) { - try { - value.retain(); + for (Map.Entry> entry : entries.subMap(first, true, last, true).entrySet()) { + Value value = getValue(entry.getKey(), entry.getValue()); + if (value != null) { values.add(value); - } catch (Throwable t) { - // Value was already destroyed between get() and retain() } } @@ -138,25 +265,65 @@ public Collection getRange(Key first, Key last) { * @return an pair of ints, containing the number of removed entries and the total size */ public Pair removeRange(Key first, Key last, boolean lastInclusive) { - Map subMap = entries.subMap(first, true, last, lastInclusive); + RemovalCounters counters = RemovalCounters.create(); + Map> 
subMap = entries.subMap(first, true, last, lastInclusive); + for (Map.Entry> entry : subMap.entrySet()) { + removeEntry(entry, counters); + } + return handleRemovalResult(counters); + } - int removedEntries = 0; - long removedSize = 0; + enum RemoveEntryResult { + ENTRY_REMOVED, + CONTINUE_LOOP, + BREAK_LOOP; + } - for (Key key : subMap.keySet()) { - Value value = entries.remove(key); - if (value == null) { - continue; - } + private RemoveEntryResult removeEntry(Map.Entry> entry, RemovalCounters counters) { + return removeEntry(entry, counters, (x) -> true); + } - removedSize += weighter.getSize(value); + private RemoveEntryResult removeEntry(Map.Entry> entry, RemovalCounters counters, + Predicate removeCondition) { + Key key = entry.getKey(); + IdentityWrapper identityWrapper = entry.getValue(); + Value value = identityWrapper.getValue(); + if (value == null || identityWrapper.getKey() != key) { + // the wrapper has been recycled: its value was cleared or it contains another key + return RemoveEntryResult.CONTINUE_LOOP; + } + try { + // add extra retain to avoid value being released while we are removing it + value.retain(); + } catch (IllegalReferenceCountException e) { + // Value was already released + return RemoveEntryResult.CONTINUE_LOOP; + } + try { + if (!removeCondition.test(value)) { + return RemoveEntryResult.BREAK_LOOP; + } + // check that the value hasn't been recycled in between + // there should be at least 2 references since this method adds one and the cache should have one + // it is valid that the value contains references even after the key has been removed from the cache + if (value.refCnt() > 1 && value.matchesKey(key) && entries.remove(key, identityWrapper)) { + identityWrapper.recycle(); + counters.entryRemoved(weighter.getSize(value)); + // remove the cache reference + value.release(); + } + } finally { + // remove the extra retain value.release(); - ++removedEntries; } + return RemoveEntryResult.ENTRY_REMOVED; + } - size.addAndGet(-removedSize); - - return Pair.of(removedEntries, 
removedSize); + private Pair handleRemovalResult(RemovalCounters counters) { + size.addAndGet(-counters.removedSize); + Pair result = Pair.of(counters.removedEntries, counters.removedSize); + counters.recycle(); + return result; } /** @@ -166,24 +333,15 @@ public Pair removeRange(Key first, Key last, boolean lastInclusiv */ public Pair evictLeastAccessedEntries(long minSize) { checkArgument(minSize > 0); - - long removedSize = 0; - int removedEntries = 0; - - while (removedSize < minSize) { - Map.Entry entry = entries.pollFirstEntry(); + RemovalCounters counters = RemovalCounters.create(); + while (counters.removedSize < minSize) { + Map.Entry> entry = entries.firstEntry(); if (entry == null) { break; } - - Value value = entry.getValue(); - ++removedEntries; - removedSize += weighter.getSize(value); - value.release(); + removeEntry(entry, counters); } - - size.addAndGet(-removedSize); - return Pair.of(removedEntries, removedSize); + return handleRemovalResult(counters); } /** @@ -192,27 +350,18 @@ public Pair evictLeastAccessedEntries(long minSize) { * @return the tota */ public Pair evictLEntriesBeforeTimestamp(long maxTimestamp) { - long removedSize = 0; - int removedCount = 0; - + RemovalCounters counters = RemovalCounters.create(); while (true) { - Map.Entry entry = entries.firstEntry(); - if (entry == null || timestampExtractor.getTimestamp(entry.getValue()) > maxTimestamp) { + Map.Entry> entry = entries.firstEntry(); + if (entry == null) { break; } - Value value = entry.getValue(); - boolean removeHits = entries.remove(entry.getKey(), value); - if (!removeHits) { + if (removeEntry(entry, counters, value -> timestampExtractor.getTimestamp(value) <= maxTimestamp) + == RemoveEntryResult.BREAK_LOOP) { break; } - - removedSize += weighter.getSize(value); - removedCount++; - value.release(); } - - size.addAndGet(-removedSize); - return Pair.of(removedCount, removedSize); + return handleRemovalResult(counters); } /** @@ -231,23 +380,16 @@ public long getSize() { * * 
@return size of removed entries */ - public synchronized Pair clear() { - long removedSize = 0; - int removedCount = 0; - + public Pair clear() { + RemovalCounters counters = RemovalCounters.create(); while (true) { - Map.Entry entry = entries.pollFirstEntry(); + Map.Entry> entry = entries.firstEntry(); if (entry == null) { break; } - Value value = entry.getValue(); - removedSize += weighter.getSize(value); - removedCount++; - value.release(); + removeEntry(entry, counters); } - - size.getAndAdd(-removedSize); - return Pair.of(removedCount, removedSize); + return handleRemovalResult(counters); } /** diff --git a/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java b/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java index 8ce0db4ac4caa..01b3c67bf1113 100644 --- a/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java +++ b/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java @@ -23,25 +23,30 @@ import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; - import com.google.common.collect.Lists; import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCounted; -import org.apache.commons.lang3.tuple.Pair; -import org.testng.annotations.Test; -import java.util.UUID; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import lombok.Cleanup; +import org.apache.commons.lang3.tuple.Pair; +import org.testng.annotations.Test; public class RangeCacheTest { - class RefString extends AbstractReferenceCounted implements ReferenceCounted { + class RefString extends AbstractReferenceCounted implements RangeCache.ValueWithKeyValidation { String s; + Integer matchingKey; RefString(String s) { + this(s, null); + } + + RefString(String s, Integer matchingKey) { super(); this.s = s; + 
this.matchingKey = matchingKey != null ? matchingKey : Integer.parseInt(s); setRefCnt(1); } @@ -65,6 +70,11 @@ public boolean equals(Object obj) { return false; } + + @Override + public boolean matchesKey(Integer key) { + return matchingKey.equals(key); + } } @Test @@ -119,8 +129,8 @@ public void simple() { public void customWeighter() { RangeCache cache = new RangeCache<>(value -> value.s.length(), x -> 0); - cache.put(0, new RefString("zero")); - cache.put(1, new RefString("one")); + cache.put(0, new RefString("zero", 0)); + cache.put(1, new RefString("one", 1)); assertEquals(cache.getSize(), 7); assertEquals(cache.getNumberOfEntries(), 2); @@ -132,9 +142,9 @@ public void customTimeExtraction() { RangeCache cache = new RangeCache<>(value -> value.s.length(), x -> x.s.length()); cache.put(1, new RefString("1")); - cache.put(2, new RefString("22")); - cache.put(3, new RefString("333")); - cache.put(4, new RefString("4444")); + cache.put(22, new RefString("22")); + cache.put(333, new RefString("333")); + cache.put(4444, new RefString("4444")); assertEquals(cache.getSize(), 10); assertEquals(cache.getNumberOfEntries(), 4); @@ -151,12 +161,12 @@ public void customTimeExtraction() { public void doubleInsert() { RangeCache cache = new RangeCache<>(); - RefString s0 = new RefString("zero"); + RefString s0 = new RefString("zero", 0); assertEquals(s0.refCnt(), 1); assertTrue(cache.put(0, s0)); assertEquals(s0.refCnt(), 1); - cache.put(1, new RefString("one")); + cache.put(1, new RefString("one", 1)); assertEquals(cache.getSize(), 2); assertEquals(cache.getNumberOfEntries(), 2); @@ -164,7 +174,7 @@ public void doubleInsert() { assertEquals(s.s, "one"); assertEquals(s.refCnt(), 2); - RefString s1 = new RefString("uno"); + RefString s1 = new RefString("uno", 1); assertEquals(s1.refCnt(), 1); assertFalse(cache.put(1, s1)); assertEquals(s1.refCnt(), 1); @@ -201,10 +211,10 @@ public void getRange() { public void eviction() { RangeCache cache = new RangeCache<>(value -> 
value.s.length(), x -> 0); - cache.put(0, new RefString("zero")); - cache.put(1, new RefString("one")); - cache.put(2, new RefString("two")); - cache.put(3, new RefString("three")); + cache.put(0, new RefString("zero", 0)); + cache.put(1, new RefString("one", 1)); + cache.put(2, new RefString("two", 2)); + cache.put(3, new RefString("three", 3)); // This should remove the LRU entries: 0, 1 whose combined size is 7 assertEquals(cache.evictLeastAccessedEntries(5), Pair.of(2, (long) 7)); @@ -276,20 +286,23 @@ public void evictions() { } @Test - public void testInParallel() { - RangeCache cache = new RangeCache<>(value -> value.s.length(), x -> 0); - ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(); - executor.scheduleWithFixedDelay(cache::clear, 10, 10, TimeUnit.MILLISECONDS); - for (int i = 0; i < 1000; i++) { - cache.put(UUID.randomUUID().toString(), new RefString("zero")); + public void testPutWhileClearIsCalledConcurrently() { + RangeCache cache = new RangeCache<>(value -> value.s.length(), x -> 0); + int numberOfThreads = 4; + @Cleanup("shutdownNow") + ScheduledExecutorService executor = Executors.newScheduledThreadPool(numberOfThreads); + for (int i = 0; i < numberOfThreads; i++) { + executor.scheduleWithFixedDelay(cache::clear, 0, 1, TimeUnit.MILLISECONDS); + } + for (int i = 0; i < 100000; i++) { + cache.put(i, new RefString(String.valueOf(i))); } - executor.shutdown(); } @Test public void testPutSameObj() { RangeCache cache = new RangeCache<>(value -> value.s.length(), x -> 0); - RefString s0 = new RefString("zero"); + RefString s0 = new RefString("zero", 0); assertEquals(s0.refCnt(), 1); assertTrue(cache.put(0, s0)); assertFalse(cache.put(0, s0));