Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add profile transport action to AD client #1119

Merged
merged 10 commits into from
Dec 27, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;

Expand Down Expand Up @@ -40,7 +41,7 @@ default ActionFuture<SearchResponse> searchAnomalyDetectors(SearchRequest search
*/
default ActionFuture<SearchResponse> searchAnomalyResults(SearchRequest searchRequest) {
PlainActionFuture<SearchResponse> actionFuture = PlainActionFuture.newFuture();
searchAnomalyDetectors(searchRequest, actionFuture);
searchAnomalyResults(searchRequest, actionFuture);
return actionFuture;
}

Expand All @@ -51,4 +52,22 @@ default ActionFuture<SearchResponse> searchAnomalyResults(SearchRequest searchRe
*/
void searchAnomalyResults(SearchRequest searchRequest, ActionListener<SearchResponse> listener);

/**
* Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector
* @param detectorId the detector ID to fetch the profile for
* @return ActionFuture of ADTaskProfileResponse
*/
default ActionFuture<ADTaskProfileResponse> getDetectorProfile(String detectorId) {
PlainActionFuture<ADTaskProfileResponse> actionFuture = PlainActionFuture.newFuture();
getDetectorProfile(detectorId, actionFuture);
return actionFuture;
}

/**
* Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector
* @param detectorId the detector ID to fetch the profile for
* @param listener a listener to be notified of the result
*/
void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener);

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,29 @@

package org.opensearch.ad.client;

import java.util.function.Function;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ad.transport.ADTaskProfileAction;
import org.opensearch.ad.transport.ADTaskProfileRequest;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.ad.transport.SearchAnomalyDetectorAction;
import org.opensearch.ad.transport.SearchAnomalyResultAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.timeseries.util.DiscoveryNodeFilterer;

public class AnomalyDetectionNodeClient implements AnomalyDetectionClient {
private final Client client;
private final DiscoveryNodeFilterer nodeFilterer;

public AnomalyDetectionNodeClient(Client client) {
public AnomalyDetectionNodeClient(Client client, ClusterService clusterService) {
this.client = client;
this.nodeFilterer = new DiscoveryNodeFilterer(clusterService);
}

@Override
Expand All @@ -32,4 +43,34 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener<Sea
listener.onResponse(searchResponse);
}, listener::onFailure));
}

@Override
public void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener) {
final DiscoveryNode[] eligibleNodes = this.nodeFilterer.getEligibleDataNodes();
ADTaskProfileRequest profileRequest = new ADTaskProfileRequest(detectorId, eligibleNodes);
this.client.execute(ADTaskProfileAction.INSTANCE, profileRequest, getADTaskProfileResponseActionListener(listener));
}

// We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic
// ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins.
private ActionListener<ADTaskProfileResponse> getADTaskProfileResponseActionListener(ActionListener<ADTaskProfileResponse> listener) {
ActionListener<ADTaskProfileResponse> internalListener = ActionListener
.wrap(profileResponse -> { listener.onResponse(profileResponse); }, listener::onFailure);
ActionListener<ADTaskProfileResponse> actionListener = wrapActionListener(internalListener, actionResponse -> {
ADTaskProfileResponse response = ADTaskProfileResponse.fromActionResponse(actionResponse);
return response;
});
return actionListener;
}

private <T extends ActionResponse> ActionListener<T> wrapActionListener(
final ActionListener<T> listener,
final Function<ActionResponse, T> recreate
) {
ActionListener<T> actionListener = ActionListener.wrap(r -> {
listener.onResponse(recreate.apply(r));
;
}, e -> { listener.onFailure(e); });
return actionListener;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@

package org.opensearch.ad.transport;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;

import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.nodes.BaseNodesResponse;
import org.opensearch.cluster.ClusterName;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

Expand All @@ -40,4 +46,18 @@ public List<ADTaskProfileNodeResponse> readNodesFrom(StreamInput in) throws IOEx
return in.readList(ADTaskProfileNodeResponse::readNodeResponse);
}

public static ADTaskProfileResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof ADTaskProfileResponse) {
return (ADTaskProfileResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new ADTaskProfileResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into ADTaskProfileResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,21 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.core.action.ActionListener;

public class AnomalyDetectionClientTests {

AnomalyDetectionClient anomalyDetectionClient;

@Mock
SearchResponse searchResponse;
SearchResponse searchDetectorsResponse;

@Mock
SearchResponse searchResultsResponse;

@Mock
ADTaskProfileResponse profileResponse;

@Before
public void setUp() {
Expand All @@ -30,24 +37,34 @@ public void setUp() {
anomalyDetectionClient = new AnomalyDetectionClient() {
@Override
public void searchAnomalyDetectors(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
listener.onResponse(searchResponse);
listener.onResponse(searchDetectorsResponse);
}

@Override
public void searchAnomalyResults(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
listener.onResponse(searchResponse);
listener.onResponse(searchResultsResponse);
}

@Override
public void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener) {
listener.onResponse(profileResponse);
}
};
}

@Test
public void searchAnomalyDetectors() {
assertEquals(searchResponse, anomalyDetectionClient.searchAnomalyDetectors(new SearchRequest()).actionGet());
assertEquals(searchDetectorsResponse, anomalyDetectionClient.searchAnomalyDetectors(new SearchRequest()).actionGet());
}

@Test
public void searchAnomalyResults() {
assertEquals(searchResponse, anomalyDetectionClient.searchAnomalyResults(new SearchRequest()).actionGet());
assertEquals(searchResultsResponse, anomalyDetectionClient.searchAnomalyResults(new SearchRequest()).actionGet());
}

@Test
public void getDetectorProfile() {
assertEquals(profileResponse, anomalyDetectionClient.getDetectorProfile("foo").actionGet());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,47 @@

package org.opensearch.ad.client;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN;
import static org.opensearch.ad.model.AnomalyDetector.DETECTOR_TYPE_FIELD;
import static org.opensearch.timeseries.TestHelpers.matchAllRequest;

import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ad.HistoricalAnalysisIntegTestCase;
import org.opensearch.ad.constant.ADCommonName;
import org.opensearch.ad.model.ADTaskProfile;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyDetectorType;
import org.opensearch.ad.transport.ADTaskProfileAction;
import org.opensearch.ad.transport.ADTaskProfileNodeResponse;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.timeseries.TestHelpers;
import org.opensearch.timeseries.constant.CommonName;

import com.google.common.collect.ImmutableList;

Expand All @@ -38,22 +54,26 @@
// The exhaustive set of transport action scenarios are within the respective transport action
// test suites themselves. We do not want to unnecessarily duplicate all of those tests here.
public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTestCase {
private final Logger logger = LogManager.getLogger(this.getClass());

private String indexName = "test-data";
private Instant startTime = Instant.now().minus(2, ChronoUnit.DAYS);
private Client clientSpy;
private AnomalyDetectionNodeClient adClient;
private PlainActionFuture<SearchResponse> future;
private PlainActionFuture<SearchResponse> searchResponseFuture;
private PlainActionFuture<ADTaskProfileResponse> profileFuture;

@Before
public void setup() {
adClient = new AnomalyDetectionNodeClient(client());
clientSpy = spy(client());
adClient = new AnomalyDetectionNodeClient(clientSpy, clusterService());
}

@Test
public void testSearchAnomalyDetectors_NoIndices() {
deleteIndexIfExists(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS);

SearchResponse searchResponse = adClient.searchAnomalyDetectors(matchAllRequest()).actionGet(10000);
SearchResponse searchResponse = adClient.searchAnomalyDetectors(TestHelpers.matchAllRequest()).actionGet(10000);
assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value);
}

Expand All @@ -62,13 +82,13 @@ public void testSearchAnomalyDetectors_Empty() throws IOException {
deleteIndexIfExists(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS);
createDetectorIndex();

SearchResponse searchResponse = adClient.searchAnomalyDetectors(matchAllRequest()).actionGet(10000);
SearchResponse searchResponse = adClient.searchAnomalyDetectors(TestHelpers.matchAllRequest()).actionGet(10000);
assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value);
}

@Test
public void searchAnomalyDetectors_Populated() throws IOException {
ingestTestData(indexName, startTime, 1, "test", 3000);
ingestTestData(indexName, startTime, 1, "test", 10);
String detectorType = AnomalyDetectorType.SINGLE_ENTITY.name();
AnomalyDetector detector = TestHelpers
.randomAnomalyDetector(
Expand All @@ -94,18 +114,18 @@ public void searchAnomalyDetectors_Populated() throws IOException {

@Test
public void testSearchAnomalyResults_NoIndices() {
future = mock(PlainActionFuture.class);
searchResponseFuture = mock(PlainActionFuture.class);
SearchRequest request = new SearchRequest().indices(new String[] {});

adClient.searchAnomalyResults(request, future);
verify(future).onFailure(any(IllegalArgumentException.class));
adClient.searchAnomalyResults(request, searchResponseFuture);
verify(searchResponseFuture).onFailure(any(IllegalArgumentException.class));
}

@Test
public void testSearchAnomalyResults_Empty() throws IOException {
createADResultIndex();
SearchResponse searchResponse = adClient
.searchAnomalyResults(matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN))
.searchAnomalyResults(TestHelpers.matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN))
.actionGet(10000);
assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value);
}
Expand All @@ -117,11 +137,52 @@ public void testSearchAnomalyResults_Populated() throws IOException {
String adResultId = createADResult(TestHelpers.randomAnomalyDetectResult());

SearchResponse searchResponse = adClient
.searchAnomalyResults(matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN))
.searchAnomalyResults(TestHelpers.matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN))
.actionGet(10000);
assertEquals(1, searchResponse.getInternalResponse().hits().getTotalHits().value);

assertEquals(1, searchResponse.getInternalResponse().hits().getTotalHits().value);
assertEquals(adResultId, searchResponse.getInternalResponse().hits().getAt(0).getId());
}

@Test
public void testGetDetectorProfile_NoIndices() throws ExecutionException, InterruptedException {
deleteIndexIfExists(CommonName.CONFIG_INDEX);
deleteIndexIfExists(ALL_AD_RESULTS_INDEX_PATTERN);
deleteIndexIfExists(ADCommonName.DETECTION_STATE_INDEX);

profileFuture = mock(PlainActionFuture.class);
ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000);
List<ADTaskProfileNodeResponse> responses = response.getNodes();

assertNotEquals(0, responses.size());
assertEquals(null, responses.get(0).getAdTaskProfile());
verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any());

}

@Test
public void testGetDetectorProfile_Populated() {
DiscoveryNode localNode = clusterService().localNode();
ADTaskProfile adTaskProfile = new ADTaskProfile("foo-task-id", 0, 0L, false, 0, 0L, localNode.getId());

doAnswer(invocation -> {
Object[] args = invocation.getArguments();

ActionListener<ADTaskProfileResponse> listener = (ActionListener<ADTaskProfileResponse>) args[2];
ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(localNode, adTaskProfile, null);

List<ADTaskProfileNodeResponse> nodeResponses = Arrays.asList(nodeResponse);
listener.onResponse(new ADTaskProfileResponse(new ClusterName("test-cluster"), nodeResponses, Collections.emptyList()));

return null;
}).when(clientSpy).execute(any(ADTaskProfileAction.class), any(), any());

ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000);
String responseTaskId = response.getNodes().get(0).getAdTaskProfile().getTaskId();

assertNotEquals(0, response.getNodes().size());
assertEquals(responseTaskId, adTaskProfile.getTaskId());
verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any());
}

}
Loading
Loading