Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[improve][ml] RangeCache refactoring: test race conditions and prevent endless loops #22814

Merged
merged 8 commits into from
May 31, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,36 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentNavigableMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicLong;
import lombok.extern.slf4j.Slf4j;
import org.apache.bookkeeper.mledger.util.RangeCache.ValueWithKeyValidation;
import org.apache.commons.lang3.tuple.Pair;

/**
* Special type of cache where get() and delete() operations can be done over a range of keys.
* The implementation avoids locks and synchronization and relies on ConcurrentSkipListMap for storing the entries.
* Since there is no locks, there is a need to have a way to ensure that a single entry in the cache is removed
* exactly once. Removing an entry multiple times would result in the entries of the cache getting released too
* while they could still be in use.
* The implementation avoids locks and synchronization by relying on ConcurrentSkipListMap for storing the entries.
* Since there are no locks, it's necessary to ensure that a single entry in the cache is removed exactly once.
* Removing an entry multiple times could result in the entries of the cache being released multiple times,
* even while they are still in use. This is prevented by using a custom wrapper around the value to store in the map
* that ensures that the value is removed from the map only if the exact same instance is present in the map.
* There's also a check that ensures that the value matches the key. This is used to detect races without impacting
* consistency.
*
* @param <Key>
* Cache key. Needs to be Comparable
* @param <Value>
* Cache value
*/
@Slf4j
public class RangeCache<Key extends Comparable<Key>, Value extends ValueWithKeyValidation<Key>> {
/**
 * Contract for the values stored in a {@link RangeCache}: a reference-counted value that can
 * tell whether it belongs to a given key.
 *
 * @param <T> type of the key the value is validated against
 */
public interface ValueWithKeyValidation<T> extends ReferenceCounted {
/**
 * Checks whether this value matches the given key.
 *
 * @param key the key to validate this value against
 * @return {@code true} if this value matches {@code key}
 */
boolean matchesKey(T key);
}

// Map from key to the wrapper holding the cached value
private final ConcurrentNavigableMap<Key, IdentityWrapper<Key, Value>> entries;
private final ConcurrentNavigableMap<Key, EntryWrapper<Key, Value>> entries;
private AtomicLong size; // Total size of values stored in cache
private final Weighter<Value> weighter; // Weighter object used to extract the size from values
private final TimestampExtractor<Value> timestampExtractor; // Extract the timestamp associated with a value
Expand All @@ -63,51 +67,53 @@ public interface ValueWithKeyValidation<T> extends ReferenceCounted {
* the map by calling the {@link Map#remove(Object, Object)} method. Certain race conditions could result in the
* wrong value being removed from the map. The instances of this class are recycled to avoid creating new objects.
*/
private static class IdentityWrapper<K, V> {
private final Handle<IdentityWrapper> recyclerHandle;
private static final Recycler<IdentityWrapper> RECYCLER = new Recycler<IdentityWrapper>() {
private static class EntryWrapper<K, V> {
private final Handle<EntryWrapper> recyclerHandle;
private static final Recycler<EntryWrapper> RECYCLER = new Recycler<EntryWrapper>() {
@Override
protected IdentityWrapper newObject(Handle<IdentityWrapper> recyclerHandle) {
return new IdentityWrapper(recyclerHandle);
protected EntryWrapper newObject(Handle<EntryWrapper> recyclerHandle) {
return new EntryWrapper(recyclerHandle);
}
};
private K key;
private V value;
long size;

private IdentityWrapper(Handle<IdentityWrapper> recyclerHandle) {
private EntryWrapper(Handle<EntryWrapper> recyclerHandle) {
this.recyclerHandle = recyclerHandle;
}

static <K, V> IdentityWrapper<K, V> create(K key, V value) {
IdentityWrapper<K, V> identityWrapper = RECYCLER.get();
identityWrapper.key = key;
identityWrapper.value = value;
return identityWrapper;
static <K, V> EntryWrapper<K, V> create(K key, V value, long size) {
EntryWrapper<K, V> entryWrapper = RECYCLER.get();
synchronized (entryWrapper) {
entryWrapper.key = key;
entryWrapper.value = value;
entryWrapper.size = size;
}
return entryWrapper;
}

K getKey() {
synchronized K getKey() {
return key;
}

V getValue() {
synchronized V getValue(K key) {
if (this.key != key) {
return null;
}
return value;
}

synchronized long getSize() {
return size;
}

void recycle() {
key = null;
value = null;
size = 0;
recyclerHandle.recycle(this);
}

@Override
public boolean equals(Object o) {
// only match exact identity of the value
return this == o;
}

@Override
public int hashCode() {
return Objects.hashCode(key);
}
}

/**
Expand Down Expand Up @@ -181,9 +187,10 @@ public boolean put(Key key, Value value) {
if (!value.matchesKey(key)) {
throw new IllegalArgumentException("Value '" + value + "' does not match key '" + key + "'");
}
IdentityWrapper<Key, Value> newWrapper = IdentityWrapper.create(key, value);
long entrySize = weighter.getSize(value);
EntryWrapper<Key, Value> newWrapper = EntryWrapper.create(key, value, entrySize);
if (entries.putIfAbsent(key, newWrapper) == null) {
size.addAndGet(weighter.getSize(value));
this.size.addAndGet(entrySize);
return true;
} else {
// recycle the new wrapper as it was not used
Expand All @@ -207,15 +214,15 @@ public Value get(Key key) {
return getValue(key, entries.get(key));
}

private Value getValue(Key key, IdentityWrapper<Key, Value> valueWrapper) {
private Value getValue(Key key, EntryWrapper<Key, Value> valueWrapper) {
if (valueWrapper == null) {
return null;
} else {
if (valueWrapper.getKey() != key) {
Value value = valueWrapper.getValue(key);
if (value == null) {
// the wrapper has been recycled and contains another key
return null;
}
Value value = valueWrapper.getValue();
try {
value.retain();
} catch (IllegalReferenceCountException e) {
Expand Down Expand Up @@ -247,7 +254,7 @@ public Collection<Value> getRange(Key first, Key last) {
List<Value> values = new ArrayList();

// Return the values of the entries found in cache
for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry : entries.subMap(first, true, last, true).entrySet()) {
for (Map.Entry<Key, EntryWrapper<Key, Value>> entry : entries.subMap(first, true, last, true).entrySet()) {
Value value = getValue(entry.getKey(), entry.getValue());
if (value != null) {
values.add(value);
Expand All @@ -266,9 +273,9 @@ public Collection<Value> getRange(Key first, Key last) {
*/
public Pair<Integer, Long> removeRange(Key first, Key last, boolean lastInclusive) {
RemovalCounters counters = RemovalCounters.create();
Map<Key, IdentityWrapper<Key, Value>> subMap = entries.subMap(first, true, last, lastInclusive);
for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry : subMap.entrySet()) {
removeEntry(entry, counters);
Map<Key, EntryWrapper<Key, Value>> subMap = entries.subMap(first, true, last, lastInclusive);
for (Map.Entry<Key, EntryWrapper<Key, Value>> entry : subMap.entrySet()) {
removeEntry(entry, counters, true);
}
return handleRemovalResult(counters);
}
Expand All @@ -279,36 +286,76 @@ enum RemoveEntryResult {
BREAK_LOOP;
}

private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key, Value>> entry, RemovalCounters counters) {
return removeEntry(entry, counters, (x) -> true);
private RemoveEntryResult removeEntry(Map.Entry<Key, EntryWrapper<Key, Value>> entry, RemovalCounters counters,
boolean skipInvalid) {
return removeEntry(entry, counters, skipInvalid, x -> true);
}

private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key, Value>> entry, RemovalCounters counters,
Predicate<Value> removeCondition) {
private RemoveEntryResult removeEntry(Map.Entry<Key, EntryWrapper<Key, Value>> entry, RemovalCounters counters,
boolean skipInvalid, Predicate<Value> removeCondition) {
Key key = entry.getKey();
IdentityWrapper<Key, Value> identityWrapper = entry.getValue();
if (identityWrapper.getKey() != key) {
// the wrapper has been recycled and contains another key
EntryWrapper<Key, Value> entryWrapper = entry.getValue();
Value value = entryWrapper.getValue(key);
if (value == null) {
// the wrapper has already been recycled and contains another key
if (!skipInvalid) {
EntryWrapper<Key, Value> removed = entries.remove(key);
if (removed != null) {
// log and remove the entry without releasing the value
log.info("Key {} does not match the entry's value wrapper's key {}, removed entry by key without "
+ "releasing the value", key, entryWrapper.getKey());
counters.entryRemoved(removed.getSize());
return RemoveEntryResult.ENTRY_REMOVED;
}
}
return RemoveEntryResult.CONTINUE_LOOP;
}
Value value = identityWrapper.getValue();
try {
// add extra retain to avoid value being released while we are removing it
value.retain();
} catch (IllegalReferenceCountException e) {
// Value was already released
if (!skipInvalid) {
// remove the specific entry without releasing the value
if (entries.remove(key, entryWrapper)) {
log.info("Value was already released for key {}, removed entry without releasing the value", key);
counters.entryRemoved(entryWrapper.getSize());
return RemoveEntryResult.ENTRY_REMOVED;
}
}
return RemoveEntryResult.CONTINUE_LOOP;
}
if (!value.matchesKey(key)) {
// this is unexpected since the EntryWrapper.getValue(key) already checked that the value matches the key
log.warn("Unexpected race condition. Value {} does not match the key {}. Removing entry.", value, key);
}
try {
if (!removeCondition.test(value)) {
return RemoveEntryResult.BREAK_LOOP;
}
// check that the value hasn't been recycled in between
// there should be at least 2 references since this method adds one and the cache should have one
// it is valid that the value contains references even after the key has been removed from the cache
if (value.refCnt() > 1 && value.matchesKey(key) && entries.remove(key, identityWrapper)) {
identityWrapper.recycle();
counters.entryRemoved(weighter.getSize(value));
if (!skipInvalid) {
// remove the specific entry
boolean entryRemoved = entries.remove(key, entryWrapper);
if (entryRemoved) {
counters.entryRemoved(entryWrapper.getSize());
// check that the value hasn't been recycled in between
// there should be at least 2 references since this method adds one and the cache should have
// one reference. it is valid that the value contains references even after the key has been
// removed from the cache
if (value.refCnt() > 1) {
entryWrapper.recycle();
// remove the cache reference
value.release();
} else {
log.info("Unexpected refCnt {} for key {}, removed entry without releasing the value",
value.refCnt(), key);
}
}
} else if (skipInvalid && value.refCnt() > 1 && entries.remove(key, entryWrapper)) {
// when skipInvalid is true, we don't remove the entry if it doesn't match the key
// or the refCnt is invalid
counters.entryRemoved(entryWrapper.getSize());
entryWrapper.recycle();
// remove the cache reference
value.release();
}
Expand All @@ -334,12 +381,12 @@ private Pair<Integer, Long> handleRemovalResult(RemovalCounters counters) {
public Pair<Integer, Long> evictLeastAccessedEntries(long minSize) {
checkArgument(minSize > 0);
RemovalCounters counters = RemovalCounters.create();
while (counters.removedSize < minSize) {
Map.Entry<Key, IdentityWrapper<Key, Value>> entry = entries.firstEntry();
while (counters.removedSize < minSize && !Thread.currentThread().isInterrupted()) {
Map.Entry<Key, EntryWrapper<Key, Value>> entry = entries.firstEntry();
if (entry == null) {
break;
}
removeEntry(entry, counters);
removeEntry(entry, counters, false);
}
return handleRemovalResult(counters);
}
Expand All @@ -351,12 +398,12 @@ public Pair<Integer, Long> evictLeastAccessedEntries(long minSize) {
*/
public Pair<Integer, Long> evictLEntriesBeforeTimestamp(long maxTimestamp) {
RemovalCounters counters = RemovalCounters.create();
while (true) {
Map.Entry<Key, IdentityWrapper<Key, Value>> entry = entries.firstEntry();
while (!Thread.currentThread().isInterrupted()) {
Map.Entry<Key, EntryWrapper<Key, Value>> entry = entries.firstEntry();
if (entry == null) {
break;
}
if (removeEntry(entry, counters, value -> timestampExtractor.getTimestamp(value) <= maxTimestamp)
if (removeEntry(entry, counters, false, value -> timestampExtractor.getTimestamp(value) <= maxTimestamp)
== RemoveEntryResult.BREAK_LOOP) {
break;
}
Expand All @@ -382,12 +429,12 @@ public long getSize() {
*/
public Pair<Integer, Long> clear() {
RemovalCounters counters = RemovalCounters.create();
while (true) {
Map.Entry<Key, IdentityWrapper<Key, Value>> entry = entries.firstEntry();
while (!Thread.currentThread().isInterrupted()) {
Map.Entry<Key, EntryWrapper<Key, Value>> entry = entries.firstEntry();
if (entry == null) {
break;
}
removeEntry(entry, counters);
removeEntry(entry, counters, false);
}
return handleRemovalResult(counters);
}
Expand Down Expand Up @@ -421,5 +468,4 @@ public long getSize(Value value) {
return 1;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ public void cacheSizeUpdate() throws Exception {
}

cacheManager.removeEntryCache(ml1.getName());
assertTrue(cacheManager.getSize() > 0);
assertEquals(factory2.getMbean().getCacheInsertedEntriesCount(), 20);
assertEquals(factory2.getMbean().getCacheEntriesCount(), 0);
assertEquals(0, cacheManager.getSize());
assertEquals(factory2.getMbean().getCacheEvictedEntriesCount(), 20);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import lombok.Cleanup;
import lombok.Data;
import org.apache.commons.lang3.tuple.Pair;
import org.awaitility.Awaitility;
import org.testng.annotations.Test;

public class RangeCacheTest {

@Data
class RefString extends AbstractReferenceCounted implements RangeCache.ValueWithKeyValidation<Integer> {
String s;
Integer matchingKey;
Expand Down Expand Up @@ -288,15 +291,21 @@ public void evictions() {
@Test
public void testPutWhileClearIsCalledConcurrently() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> 0);
int numberOfThreads = 4;
int numberOfThreads = 8;
@Cleanup("shutdownNow")
ScheduledExecutorService executor = Executors.newScheduledThreadPool(numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
executor.scheduleWithFixedDelay(cache::clear, 0, 1, TimeUnit.MILLISECONDS);
}
for (int i = 0; i < 100000; i++) {
for (int i = 0; i < 200000; i++) {
cache.put(i, new RefString(String.valueOf(i)));
}
executor.shutdown();
// ensure that no clear operation got into an endless loop
Awaitility.await().untilAsserted(() -> assertTrue(executor.isTerminated()));
// ensure that clear can be called and all entries are removed
cache.clear();
assertEquals(cache.getNumberOfEntries(), 0);
}

@Test
Expand All @@ -307,4 +316,26 @@ public void testPutSameObj() {
assertTrue(cache.put(0, s0));
assertFalse(cache.put(0, s0));
}

/**
 * Verifies that {@code clear()} removes an entry whose value was released after it was put
 * into the cache, leaving the cache with no entries.
 */
@Test
public void testRemoveEntryWithInvalidRefCount() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> 0);
RefString value = new RefString("1");
cache.put(1, value);
// release the value to make the reference count invalid
value.release();
// clear() is expected to still remove the entry, so the cache must end up empty
cache.clear();
assertEquals(cache.getNumberOfEntries(), 0);
}

/**
 * Verifies that {@code clear()} removes an entry whose value's matching key was changed after
 * it was put into the cache, leaving the cache with no entries.
 */
@Test
public void testRemoveEntryWithInvalidMatchingKey() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> 0);
RefString value = new RefString("1");
cache.put(1, value);
// change the matching key to make it invalid
value.setMatchingKey(123);
// clear() is expected to still remove the entry, so the cache must end up empty
cache.clear();
assertEquals(cache.getNumberOfEntries(), 0);
}
}
Loading