diff --git a/src/main/java/org/opensearch/ad/MemoryTracker.java b/src/main/java/org/opensearch/ad/MemoryTracker.java index 8904cd13f..fb2429a3f 100644 --- a/src/main/java/org/opensearch/ad/MemoryTracker.java +++ b/src/main/java/org/opensearch/ad/MemoryTracker.java @@ -34,6 +34,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.cluster.service.ClusterService; @@ -50,7 +51,7 @@ public class MemoryTracker { public enum Origin { SINGLE_ENTITY_DETECTOR, - MULTI_ENTITY_DETECTOR, + HC_DETECTOR, HISTORICAL_SINGLE_ENTITY_DETECTOR, } @@ -66,6 +67,7 @@ public enum Origin { // we observe threshold model uses a fixed size array and the size is the same private int thresholdModelBytes; private int sampleSize; + private ADCircuitBreakerService adCircuitBreakerService; /** * Constructor @@ -75,13 +77,15 @@ public enum Origin { * @param modelDesiredSizePercentage percentage of heap for the desired size of a model * @param clusterService Cluster service object * @param sampleSize The sample size used by stream samplers in a RCF forest + * @param adCircuitBreakerService Memory circuit breaker */ public MemoryTracker( JvmService jvmService, double modelMaxSizePercentage, double modelDesiredSizePercentage, ClusterService clusterService, - int sampleSize + int sampleSize, + ADCircuitBreakerService adCircuitBreakerService ) { this.totalMemoryBytes = 0; this.totalMemoryBytesByOrigin = new EnumMap(Origin.class); @@ -95,19 +99,19 @@ public MemoryTracker( .addSettingsUpdateConsumer(MODEL_MAX_SIZE_PERCENTAGE, it -> this.heapLimitBytes = (long) (heapSize * it)); this.thresholdModelBytes = 180_000; this.sampleSize = sampleSize; - } - - public synchronized boolean isHostingAllowed(String detectorId, RandomCutForest rcf) { - return canAllocateReserved(detectorId, 
estimateModelSize(rcf)); + this.adCircuitBreakerService = adCircuitBreakerService; } /** - * @param detectorId Detector Id, used in error message - * @param requiredBytes required bytes in memory - * @return whether there is memory required for AD + * This function derives from the old code: https://tinyurl.com/2eaabja6 + * + * @param detectorId Detector Id + * @param rcf Random cut forest model + * @return true if there is enough memory; otherwise throw LimitExceededException. */ - public synchronized boolean canAllocateReserved(String detectorId, long requiredBytes) { - if (reservedMemoryBytes + requiredBytes <= heapLimitBytes) { + public synchronized boolean isHostingAllowed(String detectorId, RandomCutForest rcf) { + long requiredBytes = estimateModelSize(rcf); + if (canAllocateReserved(requiredBytes)) { return true; } else { throw new LimitExceededException( @@ -124,12 +128,21 @@ public synchronized boolean canAllocateReserved(String detectorId, long required } /** - * Whether allocating memory is allowed + * @param requiredBytes required bytes to allocate + * @return whether there is enough memory for the required bytes. This is + * true when circuit breaker is closed and there is enough reserved memory. + */ + public synchronized boolean canAllocateReserved(long requiredBytes) { + return (false == adCircuitBreakerService.isOpen() && reservedMemoryBytes + requiredBytes <= heapLimitBytes); + } + + /** * @param bytes required bytes - * @return true if allowed; false otherwise + * @return whether there is enough memory for the required bytes. This is + * true when circuit breaker is closed and there is enough overall memory. 
*/ public synchronized boolean canAllocate(long bytes) { - return totalMemoryBytes + bytes <= heapLimitBytes; + return false == adCircuitBreakerService.isOpen() && totalMemoryBytes + bytes <= heapLimitBytes; } public synchronized void consumeMemory(long memoryToConsume, boolean reserved, Origin origin) { @@ -243,8 +256,8 @@ public long getTotalMemoryBytes() { } /** - * In case of bugs/race conditions when allocating/releasing memory, sync used bytes - * infrequently by recomputing memory usage. + * In case of bugs/race conditions or users dynamically changing dedicated/shared + * cache size, sync used bytes infrequently by recomputing memory usage. * @param origin Origin * @param totalBytes total bytes from recomputing * @param reservedBytes reserved bytes from recomputing @@ -256,6 +269,7 @@ public synchronized boolean syncMemoryState(Origin origin, long totalBytes, long if (totalBytes == recordedTotalBytes && reservedBytes == recordedReservedBytes) { return false; } + LOG .info( String diff --git a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java index f9dfedf52..30e3c6433 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java @@ -106,7 +106,7 @@ public synchronized void add(ADTask adTask) { } checkRunningTaskLimit(); long neededCacheSize = calculateADTaskCacheSize(adTask); - if (!memoryTracker.canAllocateReserved(adTask.getDetectorId(), neededCacheSize)) { + if (!memoryTracker.canAllocateReserved(neededCacheSize)) { throw new LimitExceededException("No enough memory to run detector"); } memoryTracker.consumeMemory(neededCacheSize, true, HISTORICAL_SINGLE_ENTITY_DETECTOR); diff --git a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java index 07f6c275e..61d085136 100644 --- a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java +++ 
b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java @@ -33,6 +33,7 @@ import java.util.Collections; import java.util.HashSet; +import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -67,6 +68,7 @@ public class MemoryTrackerTests extends OpenSearchTestCase { double modelDesiredSizePercentage; JvmService jvmService; AnomalyDetector detector; + ADCircuitBreakerService circuitBreaker; @Override public void setUp() throws Exception { @@ -115,18 +117,35 @@ public void setUp() throws Exception { detector = mock(AnomalyDetector.class); when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList("a")); when(detector.getShingleSize()).thenReturn(1); + + circuitBreaker = mock(ADCircuitBreakerService.class); + when(circuitBreaker.isOpen()).thenReturn(false); } private void setUpBigHeap() { ByteSizeValue value = new ByteSizeValue(largeHeapSize); when(mem.getHeapMax()).thenReturn(value); - tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, rcfSampleSize); + tracker = new MemoryTracker( + jvmService, + modelMaxSizePercentage, + modelDesiredSizePercentage, + clusterService, + rcfSampleSize, + circuitBreaker + ); } private void setUpSmallHeap() { ByteSizeValue value = new ByteSizeValue(smallHeapSize); when(mem.getHeapMax()).thenReturn(value); - tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, rcfSampleSize); + tracker = new MemoryTracker( + jvmService, + modelMaxSizePercentage, + modelDesiredSizePercentage, + clusterService, + rcfSampleSize, + circuitBreaker + ); } public void testEstimateModelSize() { @@ -145,10 +164,10 @@ public void testCanAllocate() { assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen + 10))); long bytesToUse = 100_000; - 
tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); + tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); - tracker.releaseMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); + tracker.releaseMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); assertTrue(tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); } @@ -162,11 +181,11 @@ public void testMemoryToShed() { long bytesToUse = 100_000; assertEquals(bytesToUse, tracker.getHeapLimit()); assertEquals((long) (smallHeapSize * modelDesiredSizePercentage), tracker.getDesiredModelSize()); - tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); - tracker.consumeMemory(bytesToUse, true, MemoryTracker.Origin.MULTI_ENTITY_DETECTOR); + tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); + tracker.consumeMemory(bytesToUse, true, MemoryTracker.Origin.HC_DETECTOR); assertEquals(2 * bytesToUse, tracker.getTotalMemoryBytes()); assertEquals(bytesToUse, tracker.memoryToShed()); - assertTrue(!tracker.syncMemoryState(MemoryTracker.Origin.MULTI_ENTITY_DETECTOR, 2 * bytesToUse, bytesToUse)); + assertTrue(!tracker.syncMemoryState(MemoryTracker.Origin.HC_DETECTOR, 2 * bytesToUse, bytesToUse)); } } diff --git a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java index 2f6be3e45..ba0da64fa 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java @@ -28,7 +28,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static 
org.mockito.Mockito.never; @@ -88,7 +87,7 @@ public void tearDown() throws Exception { } public void testPutTask() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask); assertEquals(1, adTaskCacheManager.size()); @@ -104,7 +103,7 @@ public void testPutTask() throws IOException { } public void testPutDuplicateTask() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask1 = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask1); assertEquals(1, adTaskCacheManager.size()); @@ -125,7 +124,7 @@ public void testPutDuplicateTask() throws IOException { } public void testPutTaskWithMemoryExceedLimit() { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(false); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(false); LimitExceededException exception = expectThrows( LimitExceededException.class, () -> adTaskCacheManager.add(TestHelpers.randomAdTask()) @@ -134,7 +133,7 @@ public void testPutTaskWithMemoryExceedLimit() { } public void testThresholdModelTrained() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask); assertEquals(1, adTaskCacheManager.size()); @@ -147,7 +146,7 @@ public void testThresholdModelTrained() throws IOException { } public void testCancel() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask); 
assertEquals(1, adTaskCacheManager.size()); @@ -174,7 +173,7 @@ public void testRemoveTaskWhichNotExist() { } public void testExceedRunningTaskLimit() throws IOException { - when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); adTaskCacheManager.add(TestHelpers.randomAdTask()); adTaskCacheManager.add(TestHelpers.randomAdTask()); assertEquals(2, adTaskCacheManager.size());