From 4943a842e1eb48ac75c094f575547660f94b31fe Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Wed, 26 May 2021 12:45:17 -0700 Subject: [PATCH] Only allocate models when circuit breaker is closed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, we could still allocate memory while the circuit breaker was open, as long as AD had not yet used more than 10% of memory. This PR disallows any new model creation when the circuit breaker is open, which provides extra safety against OOM. This PR also differentiates isHostingAllowed and canAllocateReserved. Previously, both of them threw exceptions when allocation was not allowed, but most of the time it is fine to return false and let callers decide what to do if we cannot allocate. isHostingAllowed still needs to throw an exception because of the special way the single-stream detector code is written (this can be changed). This PR keeps the original exception-throwing behavior of isHostingAllowed, but changes canAllocateReserved to return false instead. Testing done: 1. Tested that an open circuit breaker prevents model creation. 
--- .../java/org/opensearch/ad/MemoryTracker.java | 46 ++++++++++++------- .../ad/task/ADTaskCacheManager.java | 2 +- .../org/opensearch/ad/MemoryTrackerTests.java | 33 ++++++++++--- .../ad/task/ADTaskCacheManagerTests.java | 13 +++--- 4 files changed, 63 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/opensearch/ad/MemoryTracker.java b/src/main/java/org/opensearch/ad/MemoryTracker.java index 8904cd13f..fb2429a3f 100644 --- a/src/main/java/org/opensearch/ad/MemoryTracker.java +++ b/src/main/java/org/opensearch/ad/MemoryTracker.java @@ -34,6 +34,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.cluster.service.ClusterService; @@ -50,7 +51,7 @@ public class MemoryTracker { public enum Origin { SINGLE_ENTITY_DETECTOR, - MULTI_ENTITY_DETECTOR, + HC_DETECTOR, HISTORICAL_SINGLE_ENTITY_DETECTOR, } @@ -66,6 +67,7 @@ public enum Origin { // we observe threshold model uses a fixed size array and the size is the same private int thresholdModelBytes; private int sampleSize; + private ADCircuitBreakerService adCircuitBreakerService; /** * Constructor @@ -75,13 +77,15 @@ public enum Origin { * @param modelDesiredSizePercentage percentage of heap for the desired size of a model * @param clusterService Cluster service object * @param sampleSize The sample size used by stream samplers in a RCF forest + * @param adCircuitBreakerService Memory circuit breaker */ public MemoryTracker( JvmService jvmService, double modelMaxSizePercentage, double modelDesiredSizePercentage, ClusterService clusterService, - int sampleSize + int sampleSize, + ADCircuitBreakerService adCircuitBreakerService ) { this.totalMemoryBytes = 0; this.totalMemoryBytesByOrigin = new EnumMap(Origin.class); @@ -95,19 +99,19 @@ public MemoryTracker( 
.addSettingsUpdateConsumer(MODEL_MAX_SIZE_PERCENTAGE, it -> this.heapLimitBytes = (long) (heapSize * it)); this.thresholdModelBytes = 180_000; this.sampleSize = sampleSize; - } - - public synchronized boolean isHostingAllowed(String detectorId, RandomCutForest rcf) { - return canAllocateReserved(detectorId, estimateModelSize(rcf)); + this.adCircuitBreakerService = adCircuitBreakerService; } /** - * @param detectorId Detector Id, used in error message - * @param requiredBytes required bytes in memory - * @return whether there is memory required for AD + * This function derives from the old code: https://tinyurl.com/2eaabja6 + * + * @param detectorId Detector Id + * @param rcf Random cut forest model + * @return true if there is enough memory; otherwise throw LimitExceededException. */ - public synchronized boolean canAllocateReserved(String detectorId, long requiredBytes) { - if (reservedMemoryBytes + requiredBytes <= heapLimitBytes) { + public synchronized boolean isHostingAllowed(String detectorId, RandomCutForest rcf) { + long requiredBytes = estimateModelSize(rcf); + if (canAllocateReserved(requiredBytes)) { return true; } else { throw new LimitExceededException( @@ -124,12 +128,21 @@ public synchronized boolean canAllocateReserved(String detectorId, long required } /** - * Whether allocating memory is allowed + * @param requiredBytes required bytes to allocate + * @return whether there is enough memory for the required bytes. This is + * true when circuit breaker is closed and there is enough reserved memory. + */ + public synchronized boolean canAllocateReserved(long requiredBytes) { + return (false == adCircuitBreakerService.isOpen() && reservedMemoryBytes + requiredBytes <= heapLimitBytes); + } + + /** * @param bytes required bytes - * @return true if allowed; false otherwise + * @return whether there is enough memory for the required bytes. This is + * true when circuit breaker is closed and there is enough overall memory. 
*/ public synchronized boolean canAllocate(long bytes) { - return totalMemoryBytes + bytes <= heapLimitBytes; + return false == adCircuitBreakerService.isOpen() && totalMemoryBytes + bytes <= heapLimitBytes; } public synchronized void consumeMemory(long memoryToConsume, boolean reserved, Origin origin) { @@ -243,8 +256,8 @@ public long getTotalMemoryBytes() { } /** - * In case of bugs/race conditions when allocating/releasing memory, sync used bytes - * infrequently by recomputing memory usage. + * In case of bugs/race conditions or users dynamically changing dedicated/shared + * cache size, sync used bytes infrequently by recomputing memory usage. * @param origin Origin * @param totalBytes total bytes from recomputing * @param reservedBytes reserved bytes from recomputing @@ -256,6 +269,7 @@ public synchronized boolean syncMemoryState(Origin origin, long totalBytes, long if (totalBytes == recordedTotalBytes && reservedBytes == recordedReservedBytes) { return false; } + LOG .info( String diff --git a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java index f9dfedf52..30e3c6433 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java @@ -106,7 +106,7 @@ public synchronized void add(ADTask adTask) { } checkRunningTaskLimit(); long neededCacheSize = calculateADTaskCacheSize(adTask); - if (!memoryTracker.canAllocateReserved(adTask.getDetectorId(), neededCacheSize)) { + if (!memoryTracker.canAllocateReserved(neededCacheSize)) { throw new LimitExceededException("No enough memory to run detector"); } memoryTracker.consumeMemory(neededCacheSize, true, HISTORICAL_SINGLE_ENTITY_DETECTOR); diff --git a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java index 07f6c275e..61d085136 100644 --- a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java +++ 
b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java @@ -33,6 +33,7 @@ import java.util.Collections; import java.util.HashSet; +import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -67,6 +68,7 @@ public class MemoryTrackerTests extends OpenSearchTestCase { double modelDesiredSizePercentage; JvmService jvmService; AnomalyDetector detector; + ADCircuitBreakerService circuitBreaker; @Override public void setUp() throws Exception { @@ -115,18 +117,35 @@ public void setUp() throws Exception { detector = mock(AnomalyDetector.class); when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList("a")); when(detector.getShingleSize()).thenReturn(1); + + circuitBreaker = mock(ADCircuitBreakerService.class); + when(circuitBreaker.isOpen()).thenReturn(false); } private void setUpBigHeap() { ByteSizeValue value = new ByteSizeValue(largeHeapSize); when(mem.getHeapMax()).thenReturn(value); - tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, rcfSampleSize); + tracker = new MemoryTracker( + jvmService, + modelMaxSizePercentage, + modelDesiredSizePercentage, + clusterService, + rcfSampleSize, + circuitBreaker + ); } private void setUpSmallHeap() { ByteSizeValue value = new ByteSizeValue(smallHeapSize); when(mem.getHeapMax()).thenReturn(value); - tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, rcfSampleSize); + tracker = new MemoryTracker( + jvmService, + modelMaxSizePercentage, + modelDesiredSizePercentage, + clusterService, + rcfSampleSize, + circuitBreaker + ); } public void testEstimateModelSize() { @@ -145,10 +164,10 @@ public void testCanAllocate() { assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen + 10))); long bytesToUse = 100_000; - 
tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); + tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); - tracker.releaseMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); + tracker.releaseMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); assertTrue(tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); } @@ -162,11 +181,11 @@ public void testMemoryToShed() { long bytesToUse = 100_000; assertEquals(bytesToUse, tracker.getHeapLimit()); assertEquals((long) (smallHeapSize * modelDesiredSizePercentage), tracker.getDesiredModelSize()); - tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); - tracker.consumeMemory(bytesToUse, true, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); + tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); + tracker.consumeMemory(bytesToUse, true, MemoryTracker.Origin.HC_DETECTOR); assertEquals(2 * bytesToUse, tracker.getTotalMemoryBytes()); assertEquals(bytesToUse, tracker.memoryToShed()); - assertTrue(!tracker.syncMemoryState(MemoryTracker.Origin.MULTI_ENTITY_DETECTOR, 2 * bytesToUse, bytesToUse)); + assertTrue(!tracker.syncMemoryState(MemoryTracker.Origin.HC_DETECTOR, 2 * bytesToUse, bytesToUse)); } } diff --git a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java index 2f6be3e45..ba0da64fa 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java @@ -28,7 +28,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static 
org.mockito.Mockito.never; @@ -88,7 +87,7 @@ public void tearDown() throws Exception { } public void testPutTask() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask); assertEquals(1, adTaskCacheManager.size()); @@ -104,7 +103,7 @@ public void testPutTask() throws IOException { } public void testPutDuplicateTask() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask1 = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask1); assertEquals(1, adTaskCacheManager.size()); @@ -125,7 +124,7 @@ public void testPutDuplicateTask() throws IOException { } public void testPutTaskWithMemoryExceedLimit() { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(false); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(false); LimitExceededException exception = expectThrows( LimitExceededException.class, () -> adTaskCacheManager.add(TestHelpers.randomAdTask()) @@ -134,7 +133,7 @@ public void testPutTaskWithMemoryExceedLimit() { } public void testThresholdModelTrained() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask); assertEquals(1, adTaskCacheManager.size()); @@ -147,7 +146,7 @@ public void testThresholdModelTrained() throws IOException { } public void testCancel() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask); 
assertEquals(1, adTaskCacheManager.size()); @@ -174,7 +173,7 @@ public void testRemoveTaskWhichNotExist() { } public void testExceedRunningTaskLimit() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); adTaskCacheManager.add(TestHelpers.randomAdTask()); adTaskCacheManager.add(TestHelpers.randomAdTask()); assertEquals(2, adTaskCacheManager.size());