diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java index 08ae805ad..7dfd223b9 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java @@ -8,7 +8,8 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.transport.ADTaskProfileResponse; +import org.opensearch.ad.transport.GetAnomalyDetectorRequest; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; @@ -54,20 +55,20 @@ default ActionFuture searchAnomalyResults(SearchRequest searchRe /** * Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector - * @param detectorId the detector ID to fetch the profile for - * @return ActionFuture of ADTaskProfileResponse + * @param profileRequest request to fetch the detector profile + * @return ActionFuture of GetAnomalyDetectorResponse */ - default ActionFuture getDetectorProfile(String detectorId) { - PlainActionFuture actionFuture = PlainActionFuture.newFuture(); - getDetectorProfile(detectorId, actionFuture); + default ActionFuture getDetectorProfile(GetAnomalyDetectorRequest profileRequest) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + getDetectorProfile(profileRequest, actionFuture); return actionFuture; } /** * Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector - * @param detectorId the detector ID to fetch the profile for + * @param profileRequest request to fetch the detector profile * @param listener a listener to be notified of the result */ - void getDetectorProfile(String detectorId, ActionListener listener); + void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionListener listener); } diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java index 8deddc00a..60bb274ab 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java @@ -9,25 +9,20 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.transport.ADTaskProfileAction; -import org.opensearch.ad.transport.ADTaskProfileRequest; -import org.opensearch.ad.transport.ADTaskProfileResponse; +import org.opensearch.ad.transport.GetAnomalyDetectorAction; +import org.opensearch.ad.transport.GetAnomalyDetectorRequest; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.ad.transport.SearchAnomalyDetectorAction; import org.opensearch.ad.transport.SearchAnomalyResultAction; import org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; -import org.opensearch.timeseries.util.DiscoveryNodeFilterer; public class AnomalyDetectionNodeClient implements AnomalyDetectionClient { private final Client client; - private final DiscoveryNodeFilterer nodeFilterer; - public AnomalyDetectionNodeClient(Client client, ClusterService clusterService) { + public AnomalyDetectionNodeClient(Client client) { this.client = client; - this.nodeFilterer = new DiscoveryNodeFilterer(clusterService); } @Override @@ -45,19 +40,20 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener listener) { - final DiscoveryNode[] eligibleNodes = this.nodeFilterer.getEligibleDataNodes(); - ADTaskProfileRequest profileRequest = new ADTaskProfileRequest(detectorId, eligibleNodes); - this.client.execute(ADTaskProfileAction.INSTANCE, profileRequest, getADTaskProfileResponseActionListener(listener)); + public void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionListener listener) { + this.client.execute(GetAnomalyDetectorAction.INSTANCE, profileRequest, getAnomalyDetectorResponseActionListener(listener)); } // We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic // ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins. - private ActionListener getADTaskProfileResponseActionListener(ActionListener listener) { - ActionListener internalListener = ActionListener - .wrap(profileResponse -> { listener.onResponse(profileResponse); }, listener::onFailure); - ActionListener actionListener = wrapActionListener(internalListener, actionResponse -> { - ADTaskProfileResponse response = ADTaskProfileResponse.fromActionResponse(actionResponse); + private ActionListener getAnomalyDetectorResponseActionListener( + ActionListener listener + ) { + ActionListener internalListener = ActionListener.wrap(getAnomalyDetectorResponse -> { + listener.onResponse(getAnomalyDetectorResponse); + }, listener::onFailure); + ActionListener actionListener = wrapActionListener(internalListener, actionResponse -> { + GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse.fromActionResponse(actionResponse); return response; }); return actionListener; diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java index 1636e6181..f3808dab2 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java @@ -11,13 +11,18 @@ package org.opensearch.ad.transport; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfile; import org.opensearch.ad.model.EntityProfile; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.rest.RestStatus; @@ -212,4 +217,19 @@ public ADTask getHistoricalAdTask() { public AnomalyDetector getDetector() { return detector; } + + public static GetAnomalyDetectorResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof GetAnomalyDetectorResponse) { + return (GetAnomalyDetectorResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new GetAnomalyDetectorResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into GetAnomalyDetectorResponse", e); + } + } } diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java index 95bfe24d6..c5acc7064 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java @@ -13,7 +13,9 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.transport.ADTaskProfileResponse; +import org.opensearch.ad.transport.GetAnomalyDetectorRequest; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.common.lucene.uid.Versions; import org.opensearch.core.action.ActionListener; public class AnomalyDetectionClientTests { @@ -27,7 +29,7 @@ public class AnomalyDetectionClientTests { SearchResponse searchResultsResponse; @Mock - ADTaskProfileResponse profileResponse; + GetAnomalyDetectorResponse profileResponse; @Before public void setUp() { @@ -46,7 +48,7 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener listener) { + public void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionListener listener) { listener.onResponse(profileResponse); } }; @@ -64,7 +66,17 @@ public void searchAnomalyResults() { @Test public void getDetectorProfile() { - assertEquals(profileResponse, anomalyDetectionClient.getDetectorProfile("foo").actionGet()); + GetAnomalyDetectorRequest profileRequest = new GetAnomalyDetectorRequest( + "foo", + Versions.MATCH_ANY, + true, + false, + "", + "", + false, + null + ); + assertEquals(profileResponse, anomalyDetectionClient.getDetectorProfile(profileRequest).actionGet()); } } diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java index 6f5eaa37d..c142e5e3d 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java @@ -11,41 +11,42 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; import static org.opensearch.ad.model.AnomalyDetector.DETECTOR_TYPE_FIELD; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; import java.util.concurrent.ExecutionException; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.junit.Before; import org.junit.Test; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ad.HistoricalAnalysisIntegTestCase; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorType; -import org.opensearch.ad.transport.ADTaskProfileAction; -import org.opensearch.ad.transport.ADTaskProfileNodeResponse; -import org.opensearch.ad.transport.ADTaskProfileResponse; +import org.opensearch.ad.model.DetectorProfile; +import org.opensearch.ad.model.DetectorState; +import org.opensearch.ad.transport.GetAnomalyDetectorAction; +import org.opensearch.ad.transport.GetAnomalyDetectorRequest; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.client.Client; -import org.opensearch.cluster.ClusterName; -import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.lucene.uid.Versions; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Job; import com.google.common.collect.ImmutableList; @@ -54,19 +55,16 @@ // The exhaustive set of transport action scenarios are within the respective transport action // test suites themselves. We do not want to unnecessarily duplicate all of those tests here. public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTestCase { - private final Logger logger = LogManager.getLogger(this.getClass()); - private String indexName = "test-data"; private Instant startTime = Instant.now().minus(2, ChronoUnit.DAYS); private Client clientSpy; private AnomalyDetectionNodeClient adClient; private PlainActionFuture searchResponseFuture; - private PlainActionFuture profileFuture; @Before public void setup() { clientSpy = spy(client()); - adClient = new AnomalyDetectionNodeClient(clientSpy, clusterService()); + adClient = new AnomalyDetectionNodeClient(clientSpy); } @Test @@ -150,39 +148,90 @@ public void testGetDetectorProfile_NoIndices() throws ExecutionException, Interr deleteIndexIfExists(ALL_AD_RESULTS_INDEX_PATTERN); deleteIndexIfExists(ADCommonName.DETECTION_STATE_INDEX); - profileFuture = mock(PlainActionFuture.class); - ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000); - List responses = response.getNodes(); - - assertNotEquals(0, responses.size()); - assertEquals(null, responses.get(0).getAdTaskProfile()); - verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any()); - + GetAnomalyDetectorRequest profileRequest = new GetAnomalyDetectorRequest( + "foo", + Versions.MATCH_ANY, + true, + false, + "", + "", + false, + null + ); + + OpenSearchStatusException exception = expectThrows( + OpenSearchStatusException.class, + () -> adClient.getDetectorProfile(profileRequest).actionGet(10000) + ); + + assertTrue(exception.getMessage().contains(FAIL_TO_FIND_CONFIG_MSG)); + verify(clientSpy, times(1)).execute(any(GetAnomalyDetectorAction.class), any(), any()); } @Test - public void testGetDetectorProfile_Populated() { - DiscoveryNode localNode = clusterService().localNode(); - ADTaskProfile adTaskProfile = new ADTaskProfile("foo-task-id", 0, 0L, false, 0, 0L, localNode.getId()); + public void testGetDetectorProfile_Populated() throws IOException { + ingestTestData(indexName, startTime, 1, "test", 10); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(indexName), + ImmutableList.of(TestHelpers.randomFeature(true)), + null, + Instant.now(), + 1, + false, + null + ); + createDetectorIndex(); + String detectorId = createDetector(detector); doAnswer(invocation -> { Object[] args = invocation.getArguments(); - - ActionListener listener = (ActionListener) args[2]; - ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(localNode, adTaskProfile, null); - - List nodeResponses = Arrays.asList(nodeResponse); - listener.onResponse(new ADTaskProfileResponse(new ClusterName("test-cluster"), nodeResponses, Collections.emptyList())); + ActionListener listener = (ActionListener) args[2]; + + // Setting up mock profile to test that the state is returned correctly in the client response + DetectorProfile mockProfile = mock(DetectorProfile.class); + when(mockProfile.getState()).thenReturn(DetectorState.DISABLED); + + GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( + 1234, + "4567", + 9876, + 2345, + detector, + mock(Job.class), + false, + mock(ADTask.class), + mock(ADTask.class), + false, + RestStatus.OK, + mockProfile, + null, + false + ); + listener.onResponse(response); return null; - }).when(clientSpy).execute(any(ADTaskProfileAction.class), any(), any()); - - ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000); - String responseTaskId = response.getNodes().get(0).getAdTaskProfile().getTaskId(); - - assertNotEquals(0, response.getNodes().size()); - assertEquals(responseTaskId, adTaskProfile.getTaskId()); - verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any()); + }).when(clientSpy).execute(any(GetAnomalyDetectorAction.class), any(), any()); + + GetAnomalyDetectorRequest profileRequest = new GetAnomalyDetectorRequest( + detectorId, + Versions.MATCH_ANY, + true, + false, + "", + "", + false, + null + ); + + GetAnomalyDetectorResponse response = adClient.getDetectorProfile(profileRequest).actionGet(10000); + + assertNotEquals(null, response.getDetector()); + assertNotEquals(null, response.getDetectorProfile()); + assertEquals(null, response.getAdJob()); + assertEquals(detector.getName(), response.getDetector().getName()); + assertEquals(DetectorState.DISABLED, response.getDetectorProfile().getState()); + verify(clientSpy, times(1)).execute(any(GetAnomalyDetectorAction.class), any(), any()); } }