diff --git a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java b/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java index 0d350f627..ea9054eca 100644 --- a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java +++ b/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java @@ -54,7 +54,7 @@ * */ public class CompositeRetriever extends AbstractRetriever { - private static final String AGG_NAME_COMP = "comp_agg"; + public static final String AGG_NAME_COMP = "comp_agg"; private static final Logger LOG = LogManager.getLogger(CompositeRetriever.class); private final long dataStartEpoch; diff --git a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java index e73f17b7c..45bad818b 100644 --- a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java @@ -154,11 +154,11 @@ private Entity buildEntity(RestRequest request, String detectorId) throws IOExce if (entityName != null && entityValue != null) { // single-stream profile request: - // GET _opendistro/_anomaly_detection/detectors//_profile/init_progress?category_field=&entity= + // GET _plugins/_anomaly_detection/detectors//_profile/init_progress?category_field=&entity= return Entity.createSingleAttributeEntity(detectorId, entityName, entityValue); } else if (request.hasContent()) { /* HCAD profile request: - * GET _opendistro/_anomaly_detection/detectors//_profile/init_progress + * GET _plugins/_anomaly_detection/detectors//_profile/init_progress * { * "entity": [{ * "name": "clientip", diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java index c35ab699e..f7c403e01 100644 --- a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java @@ -31,6 +31,8 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY; import static org.opensearch.ad.settings.AnomalyDetectorSettings.PAGE_SIZE; @@ -39,26 +41,34 @@ import java.time.Clock; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import org.apache.commons.lang3.tuple.Pair; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; +import org.mockito.stubbing.Answer; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.search.SearchPhaseExecutionException; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponse.Clusters; +import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; @@ -71,7 +81,9 @@ import org.opensearch.ad.caching.EntityCache; import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.common.exception.EndRunException; +import org.opensearch.ad.common.exception.LimitExceededException; import org.opensearch.ad.constant.CommonErrorMessages; +import org.opensearch.ad.feature.CompositeRetriever; import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.indices.AnomalyDetectionIndices; @@ -80,6 +92,7 @@ import org.opensearch.ad.ml.ModelPartitioner; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.ratelimit.CheckpointReadWorker; import org.opensearch.ad.ratelimit.ColdEntityWorker; @@ -102,11 +115,21 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.metrics.InternalMin; import org.opensearch.test.ClusterServiceUtils; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.Transport; import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; @@ -176,6 +199,9 @@ public void setUp() throws Exception { return null; }).when(stateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + // AnomalyDetector detector = TestHelpers + // .randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES), true, true); + settings = Settings.builder().put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build(); // make sure end time is larger enough than Clock.systemUTC().millis() to get PageIterator.hasNext() to pass @@ -192,6 +218,8 @@ public void setUp() throws Exception { featureQuery = mock(FeatureManager.class); + normalModelManager = mock(ModelManager.class); + normalModelPartitioner = mock(ModelPartitioner.class); hashRing = mock(HashRing.class); @@ -252,6 +280,7 @@ public void setUp() throws Exception { when(provider.get()).thenReturn(entityCache); when(entityCache.get(any(), any())) .thenReturn(MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build())); + when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(new ArrayList(), new ArrayList())); indexUtil = mock(AnomalyDetectionIndices.class); resultWriteQueue = mock(ResultWriteWorker.class); @@ -286,6 +315,7 @@ public void testColdStartEndRunException() { assertException(listener, EndRunException.class, CommonErrorMessages.INVALID_SEARCH_QUERY_MSG); } + // a handler that forwards response or exception received from network private TransportResponseHandler entityResultHandler(TransportResponseHandler handler) { return new TransportResponseHandler() { @Override @@ -353,16 +383,11 @@ private void setUpEntityResult() { threadPool ); - when(normalModelManager.score(any(), anyString(), any())).thenReturn(new ThresholdingResult(0, 1, 1)); + when(normalModelManager.getAnomalyResultForEntity(any(), any(), any(), any(), any())).thenReturn(new ThresholdingResult(0, 1, 1)); } - /** - * Test query error causes EndRunException but not end now - * @throws InterruptedException when the await are interrupted - * @throws IOException when failing to create anomaly detector - */ @SuppressWarnings("unchecked") - public void testQueryErrorEndRunNotNow() throws InterruptedException, IOException { + public void setUpNormlaStateManager() throws IOException { ClientUtil clientUtil = mock(ClientUtil.class); AnomalyDetector detector = TestHelpers @@ -401,6 +426,15 @@ public void testQueryErrorEndRunNotNow() throws InterruptedException, IOExceptio mockThreadPool, xContentRegistry() ); + } + + /** + * Test query error causes EndRunException but not end now + * @throws InterruptedException when the await are interrupted + * @throws IOException when failing to create anomaly detector + */ + public void testQueryErrorEndRunNotNow() throws InterruptedException, IOException { + setUpNormlaStateManager(); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -438,4 +472,295 @@ public void testQueryErrorEndRunNotNow() throws InterruptedException, IOExceptio // not end now assertTrue(!((EndRunException) e).isEndNow()); } + + public void testIndexNotFound() throws InterruptedException, IOException { + setUpNormlaStateManager(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + // make PageIterator.next return failure + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("", "")); + inProgressLatch.countDown(); + return null; + }).when(client).search(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.001); + + assertTrue(inProgressLatch.await(10000L, TimeUnit.MILLISECONDS)); + + PlainActionFuture listener2 = new PlainActionFuture<>(); + action.doExecute(null, request, listener2); + Exception e = expectThrows(EndRunException.class, () -> listener2.actionGet(10000L)); + assertThat( + "actual message: " + e.getMessage(), + e.getMessage(), + containsString(AnomalyResultTransportAction.TROUBLE_QUERYING_ERR_MSG) + ); + assertTrue(!((EndRunException) e).isEndNow()); + } + + public void testEmptyFeatures() throws InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(createEmptyResponse()); + inProgressLatch.countDown(); + return null; + }).when(client).search(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgressLatch.await(10000L, TimeUnit.MILLISECONDS)); + + PlainActionFuture listener2 = new PlainActionFuture<>(); + action.doExecute(null, request, listener2); + + AnomalyResultResponse response2 = listener2.actionGet(10000L); + assertEquals(Double.NaN, response2.getAnomalyGrade(), 0.01); + } + + /** + * + * @return an empty response + */ + private SearchResponse createEmptyResponse() { + CompositeAggregation emptyComposite = mock(CompositeAggregation.class); + when(emptyComposite.getName()).thenReturn(CompositeRetriever.AGG_NAME_COMP); + when(emptyComposite.afterKey()).thenReturn(null); + // empty bucket + when(emptyComposite.getBuckets()) + .thenAnswer((Answer>) invocation -> { return new ArrayList(); }); + Aggregations emptyAggs = new Aggregations(Collections.singletonList(emptyComposite)); + SearchResponseSections emptySections = new SearchResponseSections(SearchHits.empty(), emptyAggs, null, false, null, null, 1); + return new SearchResponse(emptySections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY); + } + + private CountDownLatch setUpTransportInterceptor( + Function, TransportResponseHandler> interceptor + ) { + // set up a non-empty response + CompositeAggregation composite = mock(CompositeAggregation.class); + when(composite.getName()).thenReturn(CompositeRetriever.AGG_NAME_COMP); + Map afterKey = new HashMap<>(); + afterKey.put("service", "app_0"); + afterKey.put("host", "server_3"); + when(composite.afterKey()).thenReturn(afterKey); + + String featureID = detector.getFeatureAttributes().get(0).getId(); + List compositeBuckets = new ArrayList<>(); + CompositeAggregation.Bucket bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(Collections.singletonMap("app_0", "server_1")); + List aggList = new ArrayList<>(); + aggList.add(new InternalMin(featureID, randomDouble(), DocValueFormat.RAW, new HashMap<>())); + Aggregations aggregations = new Aggregations(aggList); + when(bucket.getAggregations()).thenReturn(aggregations); + compositeBuckets.add(bucket); + + bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(Collections.singletonMap("app_0", "server_2")); + aggList = new ArrayList<>(); + aggList.add(new InternalMin(featureID, randomDouble(), DocValueFormat.RAW, new HashMap<>())); + aggregations = new Aggregations(aggList); + when(bucket.getAggregations()).thenReturn(aggregations); + compositeBuckets.add(bucket); + + bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(Collections.singletonMap("app_0", "server_3")); + aggList = new ArrayList<>(); + aggList.add(new InternalMin(featureID, randomDouble(), DocValueFormat.RAW, new HashMap<>())); + aggregations = new Aggregations(aggList); + when(bucket.getAggregations()).thenReturn(aggregations); + compositeBuckets.add(bucket); + + when(composite.getBuckets()).thenAnswer((Answer>) invocation -> { return compositeBuckets; }); + Aggregations aggs = new Aggregations(Collections.singletonList(composite)); + + SearchResponseSections sections = new SearchResponseSections(SearchHits.empty(), aggs, null, false, null, null, 1); + SearchResponse response = new SearchResponse(sections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY); + + CountDownLatch inProgress = new CountDownLatch(2); + AtomicBoolean firstCalled = new AtomicBoolean(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + if (firstCalled.get()) { + listener.onResponse(createEmptyResponse()); + inProgress.countDown(); + } else { + listener.onResponse(response); + firstCalled.set(true); + inProgress.countDown(); + } + return null; + }).when(client).search(any(), any()); + + entityResultInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @SuppressWarnings("unchecked") + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (action.equals(EntityResultAction.NAME)) { + sender + .sendRequest( + connection, + action, + request, + options, + interceptor.apply((TransportResponseHandler) handler) + ); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + + setupTestNodes(entityResultInterceptor, settings, MAX_ENTITIES_PER_QUERY, PAGE_SIZE); + + // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor + when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(testNodes[1].discoveryNode())); + + TransportService realTransportService = testNodes[0].transportService; + ClusterService realClusterService = testNodes[0].clusterService; + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + realTransportService, + settings, + client, + stateManager, + featureQuery, + normalModelManager, + normalModelPartitioner, + hashRing, + realClusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + xContentRegistry() + ); + + return inProgress; + } + + public void testNonEmptyFeatures() throws InterruptedException { + CountDownLatch inProgress = setUpTransportInterceptor(this::entityResultHandler); + setUpEntityResult(); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + + // since we have 3 results in the first page + verify(resultWriteQueue, times(3)).put(any()); + } + + @SuppressWarnings("unchecked") + public void testCircuitBreakerOpen() throws InterruptedException { + ClientUtil clientUtil = mock(ClientUtil.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + ModelPartitioner modelPartitioner = mock(ModelPartitioner.class); + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + clientUtil, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + modelPartitioner + ); + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + stateManager, + featureQuery, + normalModelManager, + normalModelPartitioner, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + xContentRegistry() + ); + + CountDownLatch inProgress = setUpTransportInterceptor(this::entityResultHandler); + + ADCircuitBreakerService openBreaker = mock(ADCircuitBreakerService.class); + when(openBreaker.isOpen()).thenReturn(true); + // register entity result action + new EntityResultTransportAction( + new ActionFilters(Collections.emptySet()), + // since we send requests to testNodes[1] + testNodes[1].transportService, + normalModelManager, + openBreaker, + provider, + stateManager, + indexUtil, + resultWriteQueue, + checkpointReadQueue, + coldEntityQueue, + threadPool + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + + listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + assertException(listener, LimitExceededException.class, CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG); + } + + // public void testNotAck() { + // setUpTransportInterceptor(this::unackEntityResultHandler); + // setUpEntityResult(); + // + // PlainActionFuture listener = new PlainActionFuture<>(); + // + // action.doExecute(null, request, listener); + // + // assertException(listener, InternalFailure.class, AnomalyResultTransportAction.NO_ACK_ERR); + // verify(stateManager, times(1)).addPressure(anyString()); + // } } diff --git a/src/test/java/test/org/opensearch/ad/util/FakeNode.java b/src/test/java/test/org/opensearch/ad/util/FakeNode.java index ca86261aa..bf7a58062 100644 --- a/src/test/java/test/org/opensearch/ad/util/FakeNode.java +++ b/src/test/java/test/org/opensearch/ad/util/FakeNode.java @@ -40,6 +40,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; import org.apache.lucene.util.SetOnce; import org.opensearch.Version; import org.opensearch.action.admin.cluster.node.tasks.cancel.TransportCancelTasksAction; @@ -68,6 +70,8 @@ import org.opensearch.transport.nio.MockNioTransport; public class FakeNode implements Releasable { + protected static final Logger LOG = (Logger) LogManager.getLogger(FakeNode.class); + public FakeNode( String name, ThreadPool threadPool,