Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only allocate models when circuit breaker is closed #74

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions src/main/java/org/opensearch/ad/MemoryTracker.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ad.breaker.ADCircuitBreakerService;
import org.opensearch.ad.common.exception.LimitExceededException;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -50,7 +51,7 @@ public class MemoryTracker {

public enum Origin {
SINGLE_ENTITY_DETECTOR,
MULTI_ENTITY_DETECTOR,
HC_DETECTOR,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this caused a build failure. Make sure all places that use it are changed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will make sure of that in the last PR, which has all of the code.

HISTORICAL_SINGLE_ENTITY_DETECTOR,
}

Expand All @@ -66,6 +67,7 @@ public enum Origin {
// we observe threshold model uses a fixed size array and the size is the same
private int thresholdModelBytes;
private int sampleSize;
private ADCircuitBreakerService adCircuitBreakerService;

/**
* Constructor
Expand All @@ -75,13 +77,15 @@ public enum Origin {
* @param modelDesiredSizePercentage percentage of heap for the desired size of a model
* @param clusterService Cluster service object
* @param sampleSize The sample size used by stream samplers in a RCF forest
* @param adCircuitBreakerService Memory circuit breaker
*/
public MemoryTracker(
JvmService jvmService,
double modelMaxSizePercentage,
double modelDesiredSizePercentage,
ClusterService clusterService,
int sampleSize
int sampleSize,
ADCircuitBreakerService adCircuitBreakerService
) {
this.totalMemoryBytes = 0;
this.totalMemoryBytesByOrigin = new EnumMap<Origin, Long>(Origin.class);
Expand All @@ -95,19 +99,19 @@ public MemoryTracker(
.addSettingsUpdateConsumer(MODEL_MAX_SIZE_PERCENTAGE, it -> this.heapLimitBytes = (long) (heapSize * it));
this.thresholdModelBytes = 180_000;
this.sampleSize = sampleSize;
}

public synchronized boolean isHostingAllowed(String detectorId, RandomCutForest rcf) {
return canAllocateReserved(detectorId, estimateModelSize(rcf));
this.adCircuitBreakerService = adCircuitBreakerService;
}

/**
* @param detectorId Detector Id, used in error message
* @param requiredBytes required bytes in memory
* @return whether there is memory required for AD
* This function derives from the old code: https://tinyurl.com/2eaabja6
*
* @param detectorId Detector Id
* @param rcf Random cut forest model
* @return true if there is enough memory; otherwise throws LimitExceededException.
*/
public synchronized boolean canAllocateReserved(String detectorId, long requiredBytes) {
if (reservedMemoryBytes + requiredBytes <= heapLimitBytes) {
public synchronized boolean isHostingAllowed(String detectorId, RandomCutForest rcf) {
long requiredBytes = estimateModelSize(rcf);
if (canAllocateReserved(requiredBytes)) {
return true;
} else {
throw new LimitExceededException(
Expand All @@ -124,12 +128,21 @@ public synchronized boolean canAllocateReserved(String detectorId, long required
}

/**
 * Checks whether a reserved-memory allocation of the given size may proceed.
 *
 * @param requiredBytes number of bytes the caller wants to allocate
 * @return {@code true} only when the AD circuit breaker is not open and
 * adding {@code requiredBytes} to the reserved memory already recorded
 * stays within the heap limit; {@code false} otherwise
 */
public synchronized boolean canAllocateReserved(long requiredBytes) {
    // An open circuit breaker rejects the allocation regardless of available memory.
    if (adCircuitBreakerService.isOpen()) {
        return false;
    }
    return reservedMemoryBytes + requiredBytes <= heapLimitBytes;
}

/**
* @param bytes required bytes
* @return true if allowed; false otherwise
* @return whether there is enough memory for the required bytes. This is
* true when circuit breaker is closed and there is enough overall memory.
*/
public synchronized boolean canAllocate(long bytes) {
return totalMemoryBytes + bytes <= heapLimitBytes;
return false == adCircuitBreakerService.isOpen() && totalMemoryBytes + bytes <= heapLimitBytes;
}

public synchronized void consumeMemory(long memoryToConsume, boolean reserved, Origin origin) {
Expand Down Expand Up @@ -243,8 +256,8 @@ public long getTotalMemoryBytes() {
}

/**
* In case of bugs/race conditions when allocating/releasing memory, sync used bytes
* infrequently by recomputing memory usage.
* In case of bugs/race conditions or users dynamically changing dedicated/shared
* cache size, sync used bytes infrequently by recomputing memory usage.
* @param origin Origin
* @param totalBytes total bytes from recomputing
* @param reservedBytes reserved bytes from recomputing
Expand All @@ -256,6 +269,7 @@ public synchronized boolean syncMemoryState(Origin origin, long totalBytes, long
if (totalBytes == recordedTotalBytes && reservedBytes == recordedReservedBytes) {
return false;
}

LOG
.info(
String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public synchronized void add(ADTask adTask) {
}
checkRunningTaskLimit();
long neededCacheSize = calculateADTaskCacheSize(adTask);
if (!memoryTracker.canAllocateReserved(adTask.getDetectorId(), neededCacheSize)) {
if (!memoryTracker.canAllocateReserved(neededCacheSize)) {
throw new LimitExceededException("No enough memory to run detector");
}
memoryTracker.consumeMemory(neededCacheSize, true, HISTORICAL_SINGLE_ENTITY_DETECTOR);
Expand Down
33 changes: 26 additions & 7 deletions src/test/java/org/opensearch/ad/MemoryTrackerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Collections;
import java.util.HashSet;

import org.opensearch.ad.breaker.ADCircuitBreakerService;
import org.opensearch.ad.common.exception.LimitExceededException;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
Expand Down Expand Up @@ -67,6 +68,7 @@ public class MemoryTrackerTests extends OpenSearchTestCase {
double modelDesiredSizePercentage;
JvmService jvmService;
AnomalyDetector detector;
ADCircuitBreakerService circuitBreaker;

@Override
public void setUp() throws Exception {
Expand Down Expand Up @@ -115,18 +117,35 @@ public void setUp() throws Exception {
detector = mock(AnomalyDetector.class);
when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList("a"));
when(detector.getShingleSize()).thenReturn(1);

circuitBreaker = mock(ADCircuitBreakerService.class);
when(circuitBreaker.isOpen()).thenReturn(false);
}

private void setUpBigHeap() {
ByteSizeValue value = new ByteSizeValue(largeHeapSize);
when(mem.getHeapMax()).thenReturn(value);
tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, rcfSampleSize);
tracker = new MemoryTracker(
jvmService,
modelMaxSizePercentage,
modelDesiredSizePercentage,
clusterService,
rcfSampleSize,
circuitBreaker
);
}

private void setUpSmallHeap() {
ByteSizeValue value = new ByteSizeValue(smallHeapSize);
when(mem.getHeapMax()).thenReturn(value);
tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, rcfSampleSize);
tracker = new MemoryTracker(
jvmService,
modelMaxSizePercentage,
modelDesiredSizePercentage,
clusterService,
rcfSampleSize,
circuitBreaker
);
}

public void testEstimateModelSize() {
Expand All @@ -145,10 +164,10 @@ public void testCanAllocate() {
assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen + 10)));

long bytesToUse = 100_000;
tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR);
tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR);
assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen)));

tracker.releaseMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR);
tracker.releaseMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR);
assertTrue(tracker.canAllocate((long) (largeHeapSize * modelMaxPercen)));
}

Expand All @@ -162,11 +181,11 @@ public void testMemoryToShed() {
long bytesToUse = 100_000;
assertEquals(bytesToUse, tracker.getHeapLimit());
assertEquals((long) (smallHeapSize * modelDesiredSizePercentage), tracker.getDesiredModelSize());
tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR);
tracker.consumeMemory(bytesToUse, true, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR);
tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR);
tracker.consumeMemory(bytesToUse, true, MemoryTracker.Origin.HC_DETECTOR);
assertEquals(2 * bytesToUse, tracker.getTotalMemoryBytes());

assertEquals(bytesToUse, tracker.memoryToShed());
assertTrue(!tracker.syncMemoryState(MemoryTracker.Origin.MULTI_ENTITY_DETECTOR, 2 * bytesToUse, bytesToUse));
assertTrue(!tracker.syncMemoryState(MemoryTracker.Origin.HC_DETECTOR, 2 * bytesToUse, bytesToUse));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
Expand Down Expand Up @@ -88,7 +87,7 @@ public void tearDown() throws Exception {
}

public void testPutTask() throws IOException {
when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true);
when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true);
ADTask adTask = TestHelpers.randomAdTask();
adTaskCacheManager.add(adTask);
assertEquals(1, adTaskCacheManager.size());
Expand All @@ -104,7 +103,7 @@ public void testPutTask() throws IOException {
}

public void testPutDuplicateTask() throws IOException {
when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true);
when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true);
ADTask adTask1 = TestHelpers.randomAdTask();
adTaskCacheManager.add(adTask1);
assertEquals(1, adTaskCacheManager.size());
Expand All @@ -125,7 +124,7 @@ public void testPutDuplicateTask() throws IOException {
}

public void testPutTaskWithMemoryExceedLimit() {
when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(false);
when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(false);
LimitExceededException exception = expectThrows(
LimitExceededException.class,
() -> adTaskCacheManager.add(TestHelpers.randomAdTask())
Expand All @@ -134,7 +133,7 @@ public void testPutTaskWithMemoryExceedLimit() {
}

public void testThresholdModelTrained() throws IOException {
when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true);
when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true);
ADTask adTask = TestHelpers.randomAdTask();
adTaskCacheManager.add(adTask);
assertEquals(1, adTaskCacheManager.size());
Expand All @@ -147,7 +146,7 @@ public void testThresholdModelTrained() throws IOException {
}

public void testCancel() throws IOException {
when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true);
when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true);
ADTask adTask = TestHelpers.randomAdTask();
adTaskCacheManager.add(adTask);
assertEquals(1, adTaskCacheManager.size());
Expand All @@ -174,7 +173,7 @@ public void testRemoveTaskWhichNotExist() {
}

public void testExceedRunningTaskLimit() throws IOException {
when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true);
when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true);
adTaskCacheManager.add(TestHelpers.randomAdTask());
adTaskCacheManager.add(TestHelpers.randomAdTask());
assertEquals(2, adTaskCacheManager.size());
Expand Down