Fix bug in PriorityTracker
PriorityTracker uses a ConcurrentSkipListSet to record the priorities/frequencies of entities. I didn't realize that the ConcurrentSkipListSet.first() and ConcurrentSkipListSet.last() methods throw an exception when the set is empty. This PR adds an empty check.
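For illustration, a minimal sketch of the guard pattern this change adds, using a hypothetical MinPrioritySketch class with a plain Float set rather than PriorityTracker's real internals:

import java.util.Optional;
import java.util.concurrent.ConcurrentSkipListSet;

public class MinPrioritySketch {
    private final ConcurrentSkipListSet<Float> priorities = new ConcurrentSkipListSet<>();

    // ConcurrentSkipListSet.first() throws NoSuchElementException on an empty set,
    // so guard with isEmpty() and report the empty case as Optional.empty().
    public Optional<Float> minimumPriority() {
        if (priorities.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(priorities.first());
    }

    public static void main(String[] args) {
        MinPrioritySketch tracker = new MinPrioritySketch();
        System.out.println(tracker.minimumPriority()); // Optional.empty, no exception
        tracker.priorities.add(0.5f);
        System.out.println(tracker.minimumPriority()); // Optional[0.5]
    }
}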

Also, we want PriorityCache.selectUpdateCandidate to be side-effect free. This PR therefore replaces the computeBufferIfAbsent call with activeEnities.get, since computeBufferIfAbsent has the side effect of creating a CacheBuffer.
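A rough sketch of the difference, using stand-in types rather than the plugin's actual implementation:

import java.util.concurrent.ConcurrentHashMap;

public class BufferLookupSketch {
    // stand-in for the plugin's CacheBuffer; real fields omitted
    static class CacheBuffer {}

    private final ConcurrentHashMap<String, CacheBuffer> activeEnities = new ConcurrentHashMap<>();

    // has a side effect: the first lookup for a detector creates and registers a CacheBuffer
    CacheBuffer computeBufferIfAbsent(String detectorId) {
        return activeEnities.computeIfAbsent(detectorId, id -> new CacheBuffer());
    }

    // side-effect free: only reads the map and returns null when no buffer exists,
    // so a caller like selectUpdateCandidate can return empty hot/cold lists
    // instead of allocating anything
    CacheBuffer getBufferIfPresent(String detectorId) {
        return activeEnities.get(detectorId);
    }
}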

Testing done:
* created unit tests for related changes.
kaituo committed Jul 6, 2021
1 parent c02cca1 commit 43d46dc
Showing 8 changed files with 323 additions and 32 deletions.
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/ad/caching/CacheBuffer.java
@@ -297,8 +297,8 @@ public boolean canReplaceWithinDetector(float priority) {
if (items.isEmpty()) {
return false;
}
Entry<String, Float> minPriorityItem = priorityTracker.getMinimumPriority();
return minPriorityItem != null && priority > minPriorityItem.getValue();
Optional<Entry<String, Float>> minPriorityItem = priorityTracker.getMinimumPriority();
return minPriorityItem.isPresent() && priority > minPriorityItem.get().getValue();
}

/**
39 changes: 28 additions & 11 deletions src/main/java/org/opensearch/ad/caching/PriorityCache.java
@@ -325,7 +325,18 @@ public Pair<List<Entity>, List<Entity>> selectUpdateCandidate(
AnomalyDetector detector
) {
List<Entity> hotEntities = new ArrayList<>();
CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId);
List<Entity> coldEntities = new ArrayList<>();

CacheBuffer buffer = activeEnities.get(detectorId);
if (buffer == null) {
// don't want to create side-effects by creating a CacheBuffer
// In current implementation, this branch is impossible as we call
// PriorityCache.get method before invoking this method. The
// PriorityCache.get method creates a CacheBuffer if not present.
// Since this method is public, need to deal with this case in case of misuse.
return Pair.of(hotEntities, coldEntities);
}

Iterator<Entity> cacheMissEntitiesIter = cacheMissEntities.iterator();
// current buffer's dedicated cache has free slots
while (cacheMissEntitiesIter.hasNext() && buffer.dedicatedCacheAvailable()) {
@@ -384,8 +395,6 @@ public Pair<List<Entity>, List<Entity>> selectUpdateCandidate(
// check if we can replace in other CacheBuffer
cacheMissEntitiesIter = otherBufferReplaceCandidates.iterator();

List<Entity> coldEntities = new ArrayList<>();

while (cacheMissEntitiesIter.hasNext()) {
// If two threads try to remove the same entity and add their own state, the 2nd remove
// returns null and only the first one succeeds.
@@ -487,12 +496,15 @@ private Triple<CacheBuffer, String, Float> canReplaceInSharedCache(CacheBuffer o
for (Map.Entry<String, CacheBuffer> entry : activeEnities.entrySet()) {
CacheBuffer buffer = entry.getValue();
if (buffer != originBuffer && buffer.canRemove()) {
Entry<String, Float> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority();
float priority = priorityEntry.getValue();
Optional<Entry<String, Float>> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority();
if (!priorityEntry.isPresent()) {
continue;
}
float priority = priorityEntry.get().getValue();
if (candidatePriority > priority && priority < minPriority) {
minPriority = priority;
minPriorityBuffer = buffer;
minPriorityEntityModelId = priorityEntry.getKey();
minPriorityEntityModelId = priorityEntry.get().getKey();
}
}
}
@@ -533,10 +545,13 @@ private void clearMemory() {
removalCandiates = new PriorityQueue<>((x, y) -> Float.compare(x.getLeft(), y.getLeft()));
for (Map.Entry<String, CacheBuffer> entry : activeEnities.entrySet()) {
CacheBuffer buffer = entry.getValue();
Entry<String, Float> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority();
float priority = priorityEntry.getValue();
Optional<Entry<String, Float>> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority();
if (!priorityEntry.isPresent()) {
continue;
}
float priority = priorityEntry.get().getValue();
if (buffer.canRemove()) {
removalCandiates.add(Triple.of(priority, buffer, priorityEntry.getKey()));
removalCandiates.add(Triple.of(priority, buffer, priorityEntry.get().getKey()));
}
}
}
@@ -552,8 +567,10 @@ private void clearMemory() {

if (minPriorityBuffer.canRemove()) {
// can remove another one
Entry<String, Float> priorityEntry = minPriorityBuffer.getPriorityTracker().getMinimumScaledPriority();
removalCandiates.add(Triple.of(priorityEntry.getValue(), minPriorityBuffer, priorityEntry.getKey()));
Optional<Entry<String, Float>> priorityEntry = minPriorityBuffer.getPriorityTracker().getMinimumScaledPriority();
if (priorityEntry.isPresent()) {
removalCandiates.add(Triple.of(priorityEntry.get().getValue(), minPriorityBuffer, priorityEntry.get().getKey()));
}
}
}

29 changes: 22 additions & 7 deletions src/main/java/org/opensearch/ad/caching/PriorityTracker.java
@@ -184,28 +184,40 @@ public PriorityTracker(Clock clock, long intervalSecs, long landmarkEpoch, int m
/**
* Get the minimum priority entity and compute its scaled priority.
* Used to compare entity priorities among detectors.
* @return the minimum priority entity's ID and scaled priority
* @return the minimum priority entity's ID and scaled priority or Optional.empty
* if the priority list is empty
*/
public Entry<String, Float> getMinimumScaledPriority() {
public Optional<Entry<String, Float>> getMinimumScaledPriority() {
if (priorityList.isEmpty()) {
return Optional.empty();
}
PriorityNode smallest = priorityList.first();
return new SimpleImmutableEntry<>(smallest.key, getScaledPriority(smallest.priority));
return Optional.of(new SimpleImmutableEntry<>(smallest.key, getScaledPriority(smallest.priority)));
}

/**
* Get the minimum priority entity and compute its scaled priority.
* Used to compare entity priorities within the same detector.
* @return the minimum priority entity's ID and scaled priority
* @return the minimum priority entity's ID and scaled priority or Optional.empty
* if the priority list is empty
*/
public Entry<String, Float> getMinimumPriority() {
public Optional<Entry<String, Float>> getMinimumPriority() {
if (priorityList.isEmpty()) {
return Optional.empty();
}
PriorityNode smallest = priorityList.first();
return new SimpleImmutableEntry<>(smallest.key, smallest.priority);
return Optional.of(new SimpleImmutableEntry<>(smallest.key, smallest.priority));
}

/**
*
* @return the minimum priority entity's Id
* @return the minimum priority entity's Id or Optional.empty
* if the priority list is empty
*/
public Optional<String> getMinimumPriorityEntityId() {
if (priorityList.isEmpty()) {
return Optional.empty();
}
return Optional.of(priorityList).map(list -> list.first()).map(node -> node.key);
}

@@ -214,6 +226,9 @@ public Optional<String> getMinimumPriorityEntityId() {
* @return Get maximum priority entity's Id
*/
public Optional<String> getHighestPriorityEntityId() {
if (priorityList.isEmpty()) {
return Optional.empty();
}
return Optional.of(priorityList).map(list -> list.last()).map(node -> node.key);
}

@@ -50,11 +50,6 @@ public class AbstractCacheTest extends AbstractADTest {
@Before
public void setUp() throws Exception {
super.setUp();
modelId1 = "1";
modelId2 = "2";
modelId3 = "3";
modelId4 = "4";

detector = mock(AnomalyDetector.class);
detectorId = "123";
when(detector.getDetectorId()).thenReturn(detectorId);
@@ -66,6 +61,10 @@ public void setUp() throws Exception {
entity2 = Entity.createSingleAttributeEntity(detectorId, "attributeName1", "attributeVal2");
entity3 = Entity.createSingleAttributeEntity(detectorId, "attributeName1", "attributeVal3");
entity4 = Entity.createSingleAttributeEntity(detectorId, "attributeName1", "attributeVal4");
modelId1 = entity1.getModelId(detectorId).get();
modelId2 = entity2.getModelId(detectorId).get();
modelId3 = entity3.getModelId(detectorId).get();
modelId4 = entity4.getModelId(detectorId).get();

clock = mock(Clock.class);
when(clock.instant()).thenReturn(Instant.now());
7 changes: 4 additions & 3 deletions src/test/java/org/opensearch/ad/caching/CacheBufferTests.java
@@ -34,6 +34,7 @@
import java.time.Instant;
import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;

import org.mockito.ArgumentCaptor;
import org.opensearch.ad.MemoryTracker;
@@ -57,14 +58,14 @@ public void testRemovalCandidate() {
cacheBuffer.put(modelId1, modelState1);
cacheBuffer.put(modelId2, modelState2);
assertEquals(modelId1, cacheBuffer.get(modelId1).getModelId());
Entry<String, Float> removalCandidate = cacheBuffer.getPriorityTracker().getMinimumScaledPriority();
assertEquals(modelId2, removalCandidate.getKey());
Optional<Entry<String, Float>> removalCandidate = cacheBuffer.getPriorityTracker().getMinimumScaledPriority();
assertEquals(modelId2, removalCandidate.get().getKey());
cacheBuffer.remove();
cacheBuffer.put(modelId3, modelState3);
assertEquals(null, cacheBuffer.get(modelId2));
assertEquals(modelId3, cacheBuffer.get(modelId3).getModelId());
removalCandidate = cacheBuffer.getPriorityTracker().getMinimumScaledPriority();
assertEquals(modelId1, removalCandidate.getKey());
assertEquals(modelId1, removalCandidate.get().getKey());
cacheBuffer.remove(modelId1);
assertEquals(null, cacheBuffer.get(modelId1));
cacheBuffer.put(modelId4, modelState4);
166 changes: 165 additions & 1 deletion src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java
@@ -39,13 +39,17 @@

import java.time.Instant;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.junit.Before;
@@ -59,6 +63,7 @@
import org.opensearch.ad.ml.ModelManager.ModelType;
import org.opensearch.ad.ml.ModelState;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.Entity;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
@@ -200,7 +205,6 @@ public void testSharedCache() {
assertEquals(2, cacheProvider.getActiveEntities(detectorId));

for (int i = 0; i < 10; i++) {
// put in dedicated cache
cacheProvider.get(modelId3, detector2);
}
modelState3 = new ModelState<>(
@@ -432,4 +436,164 @@ public void testFailedConcurrentMaintenance() throws InterruptedException {
// we should return here
return;
}

private void selectTestCommon(int entityFreq) {
for (int i = 0; i < entityFreq; i++) {
// bypass doorkeeper
cacheProvider.get(entity1.getModelId(detectorId).get(), detector);
}
Collection<Entity> cacheMissEntities = new ArrayList<>();
cacheMissEntities.add(entity1);
Pair<List<Entity>, List<Entity>> selectedAndOther = cacheProvider.selectUpdateCandidate(cacheMissEntities, detectorId, detector);
List<Entity> selected = selectedAndOther.getLeft();
assertEquals(1, selected.size());
assertEquals(entity1, selected.get(0));
assertEquals(0, selectedAndOther.getRight().size());
}

public void testSelectToDedicatedCache() {
selectTestCommon(2);
}

public void testSelectToSharedCache() {
for (int i = 0; i < 2; i++) {
// bypass doorkeeper
cacheProvider.get(entity2.getModelId(detectorId).get(), detector);
}
when(memoryTracker.canAllocate(anyLong())).thenReturn(true);

// fill in dedicated cache
cacheProvider.hostIfPossible(detector, modelState2);
selectTestCommon(2);
verify(memoryTracker, times(1)).canAllocate(anyLong());
}

public void testSelectToReplaceInCache() {
for (int i = 0; i < 2; i++) {
// bypass doorkeeper
cacheProvider.get(entity2.getModelId(detectorId).get(), detector);
}
when(memoryTracker.canAllocate(anyLong())).thenReturn(false);

// fill in dedicated cache
cacheProvider.hostIfPossible(detector, modelState2);
// make entity1 have enough priority to replace entity2
selectTestCommon(10);
verify(memoryTracker, times(1)).canAllocate(anyLong());
}

private void replaceInOtherCacheSetUp() {
Entity entity5 = Entity.createSingleAttributeEntity(detectorId2, "attributeName1", "attributeVal5");
Entity entity6 = Entity.createSingleAttributeEntity(detectorId2, "attributeName1", "attributeVal6");
ModelState<EntityModel> modelState5 = new ModelState<>(
new EntityModel(entity5, new ArrayDeque<>(), null, null),
entity5.getModelId(detectorId2).get(),
detectorId2,
ModelType.ENTITY.getName(),
clock,
0
);
ModelState<EntityModel> modelState6 = new ModelState<>(
new EntityModel(entity6, new ArrayDeque<>(), null, null),
entity6.getModelId(detectorId2).get(),
detectorId2,
ModelType.ENTITY.getName(),
clock,
0
);

for (int i = 0; i < 3; i++) {
// bypass doorkeeper and leave room for lower frequency entity in testSelectToCold
cacheProvider.get(entity5.getModelId(detectorId2).get(), detector2);
cacheProvider.get(entity6.getModelId(detectorId2).get(), detector2);
}
for (int i = 0; i < 10; i++) {
// entity1 cannot replace entity2 due to frequency
cacheProvider.get(entity2.getModelId(detectorId).get(), detector);
}
// put modelState5 in dedicated and modelState6 in shared cache
when(memoryTracker.canAllocate(anyLong())).thenReturn(true);
cacheProvider.hostIfPossible(detector2, modelState5);
cacheProvider.hostIfPossible(detector2, modelState6);

// fill in dedicated cache
cacheProvider.hostIfPossible(detector, modelState2);

// don't allow to use shared cache afterwards
when(memoryTracker.canAllocate(anyLong())).thenReturn(false);
}

public void testSelectToReplaceInOtherCache() {
replaceInOtherCacheSetUp();

// make entity1 have enough priority to replace entity2
selectTestCommon(10);
// once when deciding whether to host modelState6;
// once when calling selectUpdateCandidate on entity1
verify(memoryTracker, times(2)).canAllocate(anyLong());
}

public void testSelectToCold() {
replaceInOtherCacheSetUp();

for (int i = 0; i < 2; i++) {
// bypass doorkeeper
cacheProvider.get(entity1.getModelId(detectorId).get(), detector);
}
Collection<Entity> cacheMissEntities = new ArrayList<>();
cacheMissEntities.add(entity1);
Pair<List<Entity>, List<Entity>> selectedAndOther = cacheProvider.selectUpdateCandidate(cacheMissEntities, detectorId, detector);
List<Entity> cold = selectedAndOther.getRight();
assertEquals(1, cold.size());
assertEquals(entity1, cold.get(0));
assertEquals(0, selectedAndOther.getLeft().size());
}

/*
* Test the scenario:
* 1. A detector's buffer uses dedicated and shared memory
* 2. a new detector's buffer is created and triggers clearMemory (every new
* CacheBuffer creation will trigger it)
* 3. clearMemory found we can reclaim shared memory
*/
public void testClearMemory() {
for (int i = 0; i < 2; i++) {
// bypass doorkeeper
cacheProvider.get(entity2.getModelId(detectorId).get(), detector);
}

for (int i = 0; i < 10; i++) {
// bypass doorkeeper and make entity1 have higher frequency
cacheProvider.get(entity1.getModelId(detectorId).get(), detector);
}

// put modelState5 in dedicated and modelState6 in shared cache
when(memoryTracker.canAllocate(anyLong())).thenReturn(true);
cacheProvider.hostIfPossible(detector, modelState1);
cacheProvider.hostIfPossible(detector, modelState2);

// two entities get inserted to cache
assertTrue(null != cacheProvider.get(entity1.getModelId(detectorId).get(), detector));
assertTrue(null != cacheProvider.get(entity2.getModelId(detectorId).get(), detector));

Entity entity5 = Entity.createSingleAttributeEntity(detectorId2, "attributeName1", "attributeVal5");
when(memoryTracker.memoryToShed()).thenReturn(memoryPerEntity);
for (int i = 0; i < 2; i++) {
// bypass doorkeeper, CacheBuffer created, and trigger clearMemory
cacheProvider.get(entity5.getModelId(detectorId2).get(), detector2);
}

assertTrue(null != cacheProvider.get(entity1.getModelId(detectorId).get(), detector));
// entity 2 removed
assertTrue(null == cacheProvider.get(entity2.getModelId(detectorId).get(), detector));
assertTrue(null == cacheProvider.get(entity5.getModelId(detectorId2).get(), detector));
}

public void testSelectEmpty() {
Collection<Entity> cacheMissEntities = new ArrayList<>();
cacheMissEntities.add(entity1);
Pair<List<Entity>, List<Entity>> selectedAndOther = cacheProvider.selectUpdateCandidate(cacheMissEntities, detectorId, detector);
assertEquals(0, selectedAndOther.getLeft().size());
assertEquals(0, selectedAndOther.getRight().size());
}
}