Skip to content

Commit

Permalink
[improve][ml] RangeCache refactoring: test race conditions and preven…
Browse files Browse the repository at this point in the history
…t endless loops (#22814)

(cherry picked from commit e731674)
  • Loading branch information
lhotari committed Jun 3, 2024
1 parent a8b8324 commit 99aed3a
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,36 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentNavigableMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicLong;
import lombok.extern.slf4j.Slf4j;
import org.apache.bookkeeper.mledger.util.RangeCache.ValueWithKeyValidation;
import org.apache.commons.lang3.tuple.Pair;

/**
* Special type of cache where get() and delete() operations can be done over a range of keys.
* The implementation avoids locks and synchronization and relies on ConcurrentSkipListMap for storing the entries.
* Since there is no locks, there is a need to have a way to ensure that a single entry in the cache is removed
* exactly once. Removing an entry multiple times would result in the entries of the cache getting released too
* while they could still be in use.
* The implementation avoids locks and synchronization by relying on ConcurrentSkipListMap for storing the entries.
* Since there are no locks, it's necessary to ensure that a single entry in the cache is removed exactly once.
* Removing an entry multiple times could result in the entries of the cache being released multiple times,
* even while they are still in use. This is prevented by using a custom wrapper around the value to store in the map
* that ensures that the value is removed from the map only if the exact same instance is present in the map.
* There's also a check that ensures that the value matches the key. This is used to detect races without impacting
* consistency.
*
* @param <Key>
* Cache key. Needs to be Comparable
* @param <Value>
* Cache value
*/
@Slf4j
public class RangeCache<Key extends Comparable<Key>, Value extends ValueWithKeyValidation<Key>> {
/**
* Contract for the values stored in the cache: a reference-counted object that can report whether it
* belongs to a given key. {@code put} throws an IllegalArgumentException when {@code matchesKey} returns
* false for the key the value is being stored under.
*
* @param <T>
* Type of the cache key
*/
public interface ValueWithKeyValidation<T> extends ReferenceCounted {
/**
* @param key the key to validate this value against
* @return true if this value matches the given key
*/
boolean matchesKey(T key);
}

// Map from key to nodes inside the linked list
private final ConcurrentNavigableMap<Key, IdentityWrapper<Key, Value>> entries;
private final ConcurrentNavigableMap<Key, EntryWrapper<Key, Value>> entries;
private AtomicLong size; // Total size of values stored in cache
private final Weighter<Value> weighter; // Weighter object used to extract the size from values
private final TimestampExtractor<Value> timestampExtractor; // Extract the timestamp associated with a value
Expand All @@ -63,51 +67,53 @@ public interface ValueWithKeyValidation<T> extends ReferenceCounted {
* the map by calling the {@link Map#remove(Object, Object)} method. Certain race conditions could result in the
* wrong value being removed from the map. The instances of this class are recycled to avoid creating new objects.
*/
private static class IdentityWrapper<K, V> {
private final Handle<IdentityWrapper> recyclerHandle;
private static final Recycler<IdentityWrapper> RECYCLER = new Recycler<IdentityWrapper>() {
private static class EntryWrapper<K, V> {
private final Handle<EntryWrapper> recyclerHandle;
private static final Recycler<EntryWrapper> RECYCLER = new Recycler<EntryWrapper>() {
@Override
protected IdentityWrapper newObject(Handle<IdentityWrapper> recyclerHandle) {
return new IdentityWrapper(recyclerHandle);
protected EntryWrapper newObject(Handle<EntryWrapper> recyclerHandle) {
return new EntryWrapper(recyclerHandle);
}
};
private K key;
private V value;
long size;

private IdentityWrapper(Handle<IdentityWrapper> recyclerHandle) {
private EntryWrapper(Handle<EntryWrapper> recyclerHandle) {
this.recyclerHandle = recyclerHandle;
}

static <K, V> IdentityWrapper<K, V> create(K key, V value) {
IdentityWrapper<K, V> identityWrapper = RECYCLER.get();
identityWrapper.key = key;
identityWrapper.value = value;
return identityWrapper;
static <K, V> EntryWrapper<K, V> create(K key, V value, long size) {
EntryWrapper<K, V> entryWrapper = RECYCLER.get();
synchronized (entryWrapper) {
entryWrapper.key = key;
entryWrapper.value = value;
entryWrapper.size = size;
}
return entryWrapper;
}

K getKey() {
synchronized K getKey() {
return key;
}

V getValue() {
synchronized V getValue(K key) {
if (this.key != key) {
return null;
}
return value;
}

/**
* Returns the size held by this wrapper: the value passed to {@code create}, or 0 once
* {@code recycle()} has reset the wrapper. Reads the field under the wrapper's monitor.
*/
synchronized long getSize() {
return size;
}

/**
 * Resets this wrapper and hands it back to the recycler so the instance can be reused.
 *
 * The fields are cleared while holding the wrapper's monitor, the same monitor that create()
 * and the synchronized getters use. Without it, a thread that still holds a reference to this
 * wrapper could observe a partially cleared state (for example the old key together with a
 * null value) or never see the reset at all.
 */
void recycle() {
    synchronized (this) {
        key = null;
        value = null;
        size = 0;
    }
    // hand the instance back only after the cleared state has been published
    recyclerHandle.recycle(this);
}

@Override
public boolean equals(Object o) {
// only match exact identity of the value
return this == o;
}

@Override
public int hashCode() {
return Objects.hashCode(key);
}
}

/**
Expand Down Expand Up @@ -181,9 +187,10 @@ public boolean put(Key key, Value value) {
if (!value.matchesKey(key)) {
throw new IllegalArgumentException("Value '" + value + "' does not match key '" + key + "'");
}
IdentityWrapper<Key, Value> newWrapper = IdentityWrapper.create(key, value);
long entrySize = weighter.getSize(value);
EntryWrapper<Key, Value> newWrapper = EntryWrapper.create(key, value, entrySize);
if (entries.putIfAbsent(key, newWrapper) == null) {
size.addAndGet(weighter.getSize(value));
this.size.addAndGet(entrySize);
return true;
} else {
// recycle the new wrapper as it was not used
Expand All @@ -207,15 +214,15 @@ public Value get(Key key) {
return getValue(key, entries.get(key));
}

private Value getValue(Key key, IdentityWrapper<Key, Value> valueWrapper) {
private Value getValue(Key key, EntryWrapper<Key, Value> valueWrapper) {
if (valueWrapper == null) {
return null;
} else {
if (valueWrapper.getKey() != key) {
Value value = valueWrapper.getValue(key);
if (value == null) {
// the wrapper has been recycled and contains another key
return null;
}
Value value = valueWrapper.getValue();
try {
value.retain();
} catch (IllegalReferenceCountException e) {
Expand Down Expand Up @@ -247,7 +254,7 @@ public Collection<Value> getRange(Key first, Key last) {
List<Value> values = new ArrayList();

// Return the values of the entries found in cache
for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry : entries.subMap(first, true, last, true).entrySet()) {
for (Map.Entry<Key, EntryWrapper<Key, Value>> entry : entries.subMap(first, true, last, true).entrySet()) {
Value value = getValue(entry.getKey(), entry.getValue());
if (value != null) {
values.add(value);
Expand All @@ -266,9 +273,9 @@ public Collection<Value> getRange(Key first, Key last) {
*/
public Pair<Integer, Long> removeRange(Key first, Key last, boolean lastInclusive) {
RemovalCounters counters = RemovalCounters.create();
Map<Key, IdentityWrapper<Key, Value>> subMap = entries.subMap(first, true, last, lastInclusive);
for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry : subMap.entrySet()) {
removeEntry(entry, counters);
Map<Key, EntryWrapper<Key, Value>> subMap = entries.subMap(first, true, last, lastInclusive);
for (Map.Entry<Key, EntryWrapper<Key, Value>> entry : subMap.entrySet()) {
removeEntry(entry, counters, true);
}
return handleRemovalResult(counters);
}
Expand All @@ -279,36 +286,76 @@ enum RemoveEntryResult {
BREAK_LOOP;
}

private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key, Value>> entry, RemovalCounters counters) {
return removeEntry(entry, counters, (x) -> true);
private RemoveEntryResult removeEntry(Map.Entry<Key, EntryWrapper<Key, Value>> entry, RemovalCounters counters,
boolean skipInvalid) {
return removeEntry(entry, counters, skipInvalid, x -> true);
}

private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key, Value>> entry, RemovalCounters counters,
Predicate<Value> removeCondition) {
private RemoveEntryResult removeEntry(Map.Entry<Key, EntryWrapper<Key, Value>> entry, RemovalCounters counters,
boolean skipInvalid, Predicate<Value> removeCondition) {
Key key = entry.getKey();
IdentityWrapper<Key, Value> identityWrapper = entry.getValue();
if (identityWrapper.getKey() != key) {
// the wrapper has been recycled and contains another key
EntryWrapper<Key, Value> entryWrapper = entry.getValue();
Value value = entryWrapper.getValue(key);
if (value == null) {
// the wrapper has already been recycled and contains another key
if (!skipInvalid) {
EntryWrapper<Key, Value> removed = entries.remove(key);
if (removed != null) {
// log and remove the entry without releasing the value
log.info("Key {} does not match the entry's value wrapper's key {}, removed entry by key without "
+ "releasing the value", key, entryWrapper.getKey());
counters.entryRemoved(removed.getSize());
return RemoveEntryResult.ENTRY_REMOVED;
}
}
return RemoveEntryResult.CONTINUE_LOOP;
}
Value value = identityWrapper.getValue();
try {
// add extra retain to avoid value being released while we are removing it
value.retain();
} catch (IllegalReferenceCountException e) {
// Value was already released
if (!skipInvalid) {
// remove the specific entry without releasing the value
if (entries.remove(key, entryWrapper)) {
log.info("Value was already released for key {}, removed entry without releasing the value", key);
counters.entryRemoved(entryWrapper.getSize());
return RemoveEntryResult.ENTRY_REMOVED;
}
}
return RemoveEntryResult.CONTINUE_LOOP;
}
if (!value.matchesKey(key)) {
// this is unexpected since EntryWrapper.getValue(key) already verified that the wrapper holds this key
log.warn("Unexpected race condition. Value {} does not match the key {}. Removing entry.", value, key);
}
try {
if (!removeCondition.test(value)) {
return RemoveEntryResult.BREAK_LOOP;
}
// check that the value hasn't been recycled in between
// there should be at least 2 references since this method adds one and the cache should have one
// it is valid that the value contains references even after the key has been removed from the cache
if (value.refCnt() > 1 && value.matchesKey(key) && entries.remove(key, identityWrapper)) {
identityWrapper.recycle();
counters.entryRemoved(weighter.getSize(value));
if (!skipInvalid) {
// remove the specific entry
boolean entryRemoved = entries.remove(key, entryWrapper);
if (entryRemoved) {
counters.entryRemoved(entryWrapper.getSize());
// check that the value hasn't been recycled in between
// there should be at least 2 references since this method adds one and the cache should have
// one reference. it is valid that the value contains references even after the key has been
// removed from the cache
if (value.refCnt() > 1) {
entryWrapper.recycle();
// remove the cache reference
value.release();
} else {
log.info("Unexpected refCnt {} for key {}, removed entry without releasing the value",
value.refCnt(), key);
}
}
} else if (skipInvalid && value.refCnt() > 1 && entries.remove(key, entryWrapper)) {
// when skipInvalid is true, we don't remove the entry if it doesn't match the key
// or the refCnt is invalid
counters.entryRemoved(entryWrapper.getSize());
entryWrapper.recycle();
// remove the cache reference
value.release();
}
Expand All @@ -334,12 +381,12 @@ private Pair<Integer, Long> handleRemovalResult(RemovalCounters counters) {
public Pair<Integer, Long> evictLeastAccessedEntries(long minSize) {
checkArgument(minSize > 0);
RemovalCounters counters = RemovalCounters.create();
while (counters.removedSize < minSize) {
Map.Entry<Key, IdentityWrapper<Key, Value>> entry = entries.firstEntry();
while (counters.removedSize < minSize && !Thread.currentThread().isInterrupted()) {
Map.Entry<Key, EntryWrapper<Key, Value>> entry = entries.firstEntry();
if (entry == null) {
break;
}
removeEntry(entry, counters);
removeEntry(entry, counters, false);
}
return handleRemovalResult(counters);
}
Expand All @@ -351,12 +398,12 @@ public Pair<Integer, Long> evictLeastAccessedEntries(long minSize) {
*/
public Pair<Integer, Long> evictLEntriesBeforeTimestamp(long maxTimestamp) {
RemovalCounters counters = RemovalCounters.create();
while (true) {
Map.Entry<Key, IdentityWrapper<Key, Value>> entry = entries.firstEntry();
while (!Thread.currentThread().isInterrupted()) {
Map.Entry<Key, EntryWrapper<Key, Value>> entry = entries.firstEntry();
if (entry == null) {
break;
}
if (removeEntry(entry, counters, value -> timestampExtractor.getTimestamp(value) <= maxTimestamp)
if (removeEntry(entry, counters, false, value -> timestampExtractor.getTimestamp(value) <= maxTimestamp)
== RemoveEntryResult.BREAK_LOOP) {
break;
}
Expand All @@ -382,12 +429,12 @@ public long getSize() {
*/
public Pair<Integer, Long> clear() {
RemovalCounters counters = RemovalCounters.create();
while (true) {
Map.Entry<Key, IdentityWrapper<Key, Value>> entry = entries.firstEntry();
while (!Thread.currentThread().isInterrupted()) {
Map.Entry<Key, EntryWrapper<Key, Value>> entry = entries.firstEntry();
if (entry == null) {
break;
}
removeEntry(entry, counters);
removeEntry(entry, counters, false);
}
return handleRemovalResult(counters);
}
Expand Down Expand Up @@ -421,5 +468,4 @@ public long getSize(Value value) {
return 1;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ public void cacheSizeUpdate() throws Exception {
}

cacheManager.removeEntryCache(ml1.getName());
assertTrue(cacheManager.getSize() > 0);
assertEquals(factory2.getMbean().getCacheInsertedEntriesCount(), 20);
assertEquals(factory2.getMbean().getCacheEntriesCount(), 0);
assertEquals(0, cacheManager.getSize());
assertEquals(factory2.getMbean().getCacheEvictedEntriesCount(), 20);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import lombok.Cleanup;
import lombok.Data;
import org.apache.commons.lang3.tuple.Pair;
import org.awaitility.Awaitility;
import org.testng.annotations.Test;

public class RangeCacheTest {

@Data
class RefString extends AbstractReferenceCounted implements RangeCache.ValueWithKeyValidation<Integer> {
String s;
Integer matchingKey;
Expand Down Expand Up @@ -288,15 +291,21 @@ public void evictions() {
@Test
public void testPutWhileClearIsCalledConcurrently() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> 0);
int numberOfThreads = 4;
int numberOfThreads = 8;
@Cleanup("shutdownNow")
ScheduledExecutorService executor = Executors.newScheduledThreadPool(numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
executor.scheduleWithFixedDelay(cache::clear, 0, 1, TimeUnit.MILLISECONDS);
}
for (int i = 0; i < 100000; i++) {
for (int i = 0; i < 200000; i++) {
cache.put(i, new RefString(String.valueOf(i)));
}
executor.shutdown();
// ensure that no clear operation got into endless loop
Awaitility.await().untilAsserted(() -> assertTrue(executor.isTerminated()));
// ensure that clear can be called and all entries are removed
cache.clear();
assertEquals(cache.getNumberOfEntries(), 0);
}

@Test
Expand All @@ -307,4 +316,26 @@ public void testPutSameObj() {
assertTrue(cache.put(0, s0));
assertFalse(cache.put(0, s0));
}

@Test
public void testRemoveEntryWithInvalidRefCount() {
    // Verifies that clear() still empties the cache when an entry's value has been released externally.
    RangeCache<Integer, RefString> rangeCache = new RangeCache<>(v -> v.s.length(), ignored -> 0);
    RefString releasedEntry = new RefString("1");
    rangeCache.put(1, releasedEntry);

    // an extra release outside of the cache leaves the value with an invalid reference count
    releasedEntry.release();

    rangeCache.clear();
    assertEquals(rangeCache.getNumberOfEntries(), 0);
}

@Test
public void testRemoveEntryWithInvalidMatchingKey() {
    // Verifies that clear() still empties the cache when a stored value no longer matches its key.
    RangeCache<Integer, RefString> rangeCache = new RangeCache<>(v -> v.s.length(), ignored -> 0);
    RefString mismatchedEntry = new RefString("1");
    rangeCache.put(1, mismatchedEntry);

    // point the value at a different key than the one it is stored under
    mismatchedEntry.setMatchingKey(123);

    rangeCache.clear();
    assertEquals(rangeCache.getNumberOfEntries(), 0);
}
}

0 comments on commit 99aed3a

Please sign in to comment.