From b87ded2125f551dfc336a3ca82e1bfb119c4cf82 Mon Sep 17 00:00:00 2001 From: Tyler Ohlsen Date: Wed, 27 Dec 2023 22:40:00 +0000 Subject: [PATCH] Refactor client profile method to use GetAnomalyDetectorTransportAction Signed-off-by: Tyler Ohlsen --- .../ad/client/AnomalyDetectionClient.java | 10 +- .../ad/client/AnomalyDetectionNodeClient.java | 34 ++++--- .../transport/GetAnomalyDetectorResponse.java | 20 ++++ .../client/AnomalyDetectionClientTests.java | 6 +- .../AnomalyDetectionNodeClientTests.java | 95 ++++++++++++------- 5 files changed, 106 insertions(+), 59 deletions(-) diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java index 08ae805ad..fb1efb5a0 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java @@ -8,7 +8,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.transport.ADTaskProfileResponse; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; @@ -55,10 +55,10 @@ default ActionFuture searchAnomalyResults(SearchRequest searchRe /** * Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector * @param detectorId the detector ID to fetch the profile for - * @return ActionFuture of ADTaskProfileResponse + * @return ActionFuture of GetAnomalyDetectorResponse */ - default ActionFuture getDetectorProfile(String detectorId) { - PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + default ActionFuture getDetectorProfile(String detectorId) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); getDetectorProfile(detectorId, actionFuture); return actionFuture; } @@ -68,6 +68,6 @@ default ActionFuture getDetectorProfile(String detectorId * @param detectorId the detector ID to fetch the profile for * @param listener a listener to be notified of the result */ - void getDetectorProfile(String detectorId, ActionListener listener); + void getDetectorProfile(String detectorId, ActionListener listener); } diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java index 8deddc00a..9aecb6b37 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java @@ -9,25 +9,21 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.transport.ADTaskProfileAction; -import org.opensearch.ad.transport.ADTaskProfileRequest; -import org.opensearch.ad.transport.ADTaskProfileResponse; +import org.opensearch.ad.transport.GetAnomalyDetectorAction; +import org.opensearch.ad.transport.GetAnomalyDetectorRequest; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.ad.transport.SearchAnomalyDetectorAction; import org.opensearch.ad.transport.SearchAnomalyResultAction; import org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.lucene.uid.Versions; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; -import org.opensearch.timeseries.util.DiscoveryNodeFilterer; public class AnomalyDetectionNodeClient implements AnomalyDetectionClient { private final Client client; - private final DiscoveryNodeFilterer nodeFilterer; - public AnomalyDetectionNodeClient(Client client, ClusterService clusterService) { + public AnomalyDetectionNodeClient(Client client) { this.client = client; - this.nodeFilterer = new DiscoveryNodeFilterer(clusterService); } @Override @@ -45,19 +41,21 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener listener) { - final DiscoveryNode[] eligibleNodes = this.nodeFilterer.getEligibleDataNodes(); - ADTaskProfileRequest profileRequest = new ADTaskProfileRequest(detectorId, eligibleNodes); - this.client.execute(ADTaskProfileAction.INSTANCE, profileRequest, getADTaskProfileResponseActionListener(listener)); + public void getDetectorProfile(String detectorId, ActionListener listener) { + GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest(detectorId, Versions.MATCH_ANY, true, false, "", "", false, null); + this.client.execute(GetAnomalyDetectorAction.INSTANCE, request, getAnomalyDetectorResponseActionListener(listener)); } // We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic // ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins. - private ActionListener getADTaskProfileResponseActionListener(ActionListener listener) { - ActionListener internalListener = ActionListener - .wrap(profileResponse -> { listener.onResponse(profileResponse); }, listener::onFailure); - ActionListener actionListener = wrapActionListener(internalListener, actionResponse -> { - ADTaskProfileResponse response = ADTaskProfileResponse.fromActionResponse(actionResponse); + private ActionListener getAnomalyDetectorResponseActionListener( + ActionListener listener + ) { + ActionListener internalListener = ActionListener.wrap(getAnomalyDetectorResponse -> { + listener.onResponse(getAnomalyDetectorResponse); + }, listener::onFailure); + ActionListener actionListener = wrapActionListener(internalListener, actionResponse -> { + GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse.fromActionResponse(actionResponse); return response; }); return actionListener; diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java index 1636e6181..f3808dab2 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java @@ -11,13 +11,18 @@ package org.opensearch.ad.transport; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfile; import org.opensearch.ad.model.EntityProfile; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.rest.RestStatus; @@ -212,4 +217,19 @@ public ADTask getHistoricalAdTask() { public AnomalyDetector getDetector() { return detector; } + + public static GetAnomalyDetectorResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof GetAnomalyDetectorResponse) { + return (GetAnomalyDetectorResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new GetAnomalyDetectorResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into GetAnomalyDetectorResponse", e); + } + } } diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java index 95bfe24d6..6a92932a0 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java @@ -13,7 +13,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.transport.ADTaskProfileResponse; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.core.action.ActionListener; public class AnomalyDetectionClientTests { @@ -27,7 +27,7 @@ public class AnomalyDetectionClientTests { SearchResponse searchResultsResponse; @Mock - ADTaskProfileResponse profileResponse; + GetAnomalyDetectorResponse profileResponse; @Before public void setUp() { @@ -46,7 +46,7 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener listener) { + public void getDetectorProfile(String detectorId, ActionListener listener) { listener.onResponse(profileResponse); } }; diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java index 6f5eaa37d..49dda6782 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java @@ -11,41 +11,42 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; import static org.opensearch.ad.model.AnomalyDetector.DETECTOR_TYPE_FIELD; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; import java.util.concurrent.ExecutionException; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.junit.Before; import org.junit.Test; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ad.HistoricalAnalysisIntegTestCase; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorType; -import org.opensearch.ad.transport.ADTaskProfileAction; -import org.opensearch.ad.transport.ADTaskProfileNodeResponse; -import org.opensearch.ad.transport.ADTaskProfileResponse; +import org.opensearch.ad.model.DetectorProfile; +import org.opensearch.ad.model.DetectorState; +import org.opensearch.ad.transport.GetAnomalyDetectorAction; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.client.Client; -import org.opensearch.cluster.ClusterName; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Job; import com.google.common.collect.ImmutableList; @@ -61,12 +62,11 @@ public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTest private Client clientSpy; private AnomalyDetectionNodeClient adClient; private PlainActionFuture searchResponseFuture; - private PlainActionFuture profileFuture; @Before public void setup() { clientSpy = spy(client()); - adClient = new AnomalyDetectionNodeClient(clientSpy, clusterService()); + adClient = new AnomalyDetectionNodeClient(clientSpy); } @Test @@ -150,39 +150,68 @@ public void testGetDetectorProfile_NoIndices() throws ExecutionException, Interr deleteIndexIfExists(ALL_AD_RESULTS_INDEX_PATTERN); deleteIndexIfExists(ADCommonName.DETECTION_STATE_INDEX); - profileFuture = mock(PlainActionFuture.class); - ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000); - List responses = response.getNodes(); - - assertNotEquals(0, responses.size()); - assertEquals(null, responses.get(0).getAdTaskProfile()); - verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any()); + OpenSearchStatusException exception = expectThrows( + OpenSearchStatusException.class, + () -> adClient.getDetectorProfile("foo").actionGet(10000) + ); + assertTrue(exception.getMessage().contains(FAIL_TO_FIND_CONFIG_MSG)); + verify(clientSpy, times(1)).execute(any(GetAnomalyDetectorAction.class), any(), any()); } @Test - public void testGetDetectorProfile_Populated() { - DiscoveryNode localNode = clusterService().localNode(); - ADTaskProfile adTaskProfile = new ADTaskProfile("foo-task-id", 0, 0L, false, 0, 0L, localNode.getId()); + public void testGetDetectorProfile_Populated() throws IOException { + ingestTestData(indexName, startTime, 1, "test", 10); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(indexName), + ImmutableList.of(TestHelpers.randomFeature(true)), + null, + Instant.now(), + 1, + false, + null + ); + createDetectorIndex(); + String detectorId = createDetector(detector); doAnswer(invocation -> { Object[] args = invocation.getArguments(); - - ActionListener listener = (ActionListener) args[2]; - ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(localNode, adTaskProfile, null); - - List nodeResponses = Arrays.asList(nodeResponse); - listener.onResponse(new ADTaskProfileResponse(new ClusterName("test-cluster"), nodeResponses, Collections.emptyList())); + ActionListener listener = (ActionListener) args[2]; + + // Setting up mock profile to test that the state is returned correctly in the client response + DetectorProfile mockProfile = mock(DetectorProfile.class); + when(mockProfile.getState()).thenReturn(DetectorState.DISABLED); + + GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( + 1234, + "4567", + 9876, + 2345, + detector, + mock(Job.class), + false, + mock(ADTask.class), + mock(ADTask.class), + false, + RestStatus.OK, + mockProfile, + null, + false + ); + listener.onResponse(response); return null; - }).when(clientSpy).execute(any(ADTaskProfileAction.class), any(), any()); + }).when(clientSpy).execute(any(GetAnomalyDetectorAction.class), any(), any()); - ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000); - String responseTaskId = response.getNodes().get(0).getAdTaskProfile().getTaskId(); + GetAnomalyDetectorResponse response = adClient.getDetectorProfile(detectorId).actionGet(10000); - assertNotEquals(0, response.getNodes().size()); - assertEquals(responseTaskId, adTaskProfile.getTaskId()); - verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any()); + assertNotEquals(null, response.getDetector()); + assertNotEquals(null, response.getDetectorProfile()); + assertEquals(null, response.getAdJob()); + assertEquals(detector.getName(), response.getDetector().getName()); + assertEquals(DetectorState.DISABLED, response.getDetectorProfile().getState()); + verify(clientSpy, times(1)).execute(any(GetAnomalyDetectorAction.class), any(), any()); } }