Skip to content

Commit

Permalink
Fix detector state params in SearchAnomalyDetectorsTool (#235)
Browse files Browse the repository at this point in the history
* Inject NamedWriteableRegistry into AD tools

Signed-off-by: Tyler Ohlsen <[email protected]>

* Fix detector state param filtering

Signed-off-by: Tyler Ohlsen <[email protected]>

* Enforce ordering on search AD ITs

Signed-off-by: Tyler Ohlsen <[email protected]>

* Clean up filtering logic more; fix and add UT/IT

Signed-off-by: Tyler Ohlsen <[email protected]>

---------

Signed-off-by: Tyler Ohlsen <[email protected]>
  • Loading branch information
ohltyler authored Feb 26, 2024
1 parent bd510a5 commit 44dc232
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 75 deletions.
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ public Collection<Object> createComponents(
SearchIndexTool.Factory.getInstance().init(client, xContentRegistry);
RAGTool.Factory.getInstance().init(client, xContentRegistry);
SearchAlertsTool.Factory.getInstance().init(client);
SearchAnomalyDetectorsTool.Factory.getInstance().init(client);
SearchAnomalyResultsTool.Factory.getInstance().init(client);
SearchAnomalyDetectorsTool.Factory.getInstance().init(client, namedWriteableRegistry);
SearchAnomalyResultsTool.Factory.getInstance().init(client, namedWriteableRegistry);
SearchMonitorsTool.Factory.getInstance().init(client);
return Collections.emptyList();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
Expand All @@ -26,6 +27,7 @@
import org.opensearch.client.Client;
import org.opensearch.common.lucene.uid.Versions;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
Expand All @@ -49,7 +51,7 @@
public class SearchAnomalyDetectorsTool implements Tool {
public static final String TYPE = "SearchAnomalyDetectorsTool";
private static final String DEFAULT_DESCRIPTION =
"This is a tool that searches anomaly detectors. It takes 12 optional arguments named detectorName which is the explicit name of the monitor (default is null), and detectorNamePattern which is a wildcard query to match detector name (default is null), and indices which defines the index or index pattern the detector is detecting over (default is null), and highCardinality which defines whether the anomaly detector is high cardinality (synonymous with multi-entity) of non-high-cardinality (synonymous with single-entity) (default is null, indicating both), and lastUpdateTime which defines the latest update time of the anomaly detector in epoch milliseconds (default is null), and sortOrder which defines the order of the results (options are asc or desc, and default is asc), and sortString which defines how to sort the results (default is name.keyword), and size which defines the size of the request to be returned (default is 20), and startIndex which defines the paginated index to start from (default is 0), and running which defines whether the anomaly detector is running (default is null, indicating both), and disabled which defines whether the anomaly detector is disabled (default is null, indicating both), and failed which defines whether the anomaly detector has failed (default is null, indicating both). The tool returns 2 values: a list of anomaly detectors (each containing the detector id, detector name, detector type indicating multi-entity or single-entity (where multi-entity also means high-cardinality), detector description, name of the configured index, last update time in epoch milliseconds), and the total number of anomaly detectors.";
"This is a tool that searches anomaly detectors. It takes 12 optional arguments named detectorName which is the explicit name of the detector (default is null), and detectorNamePattern which is a wildcard query to match detector name (default is null), and indices which defines the index or index pattern the detector is detecting over (default is null), and highCardinality which defines whether the anomaly detector is high cardinality (synonymous with multi-entity) of non-high-cardinality (synonymous with single-entity) (default is null, indicating both), and lastUpdateTime which defines the latest update time of the anomaly detector in epoch milliseconds (default is null), and sortOrder which defines the order of the results (options are asc or desc, and default is asc), and sortString which defines how to sort the results (default is name.keyword), and size which defines the size of the request to be returned (default is 20), and startIndex which defines the paginated index to start from (default is 0), and running which defines whether the anomaly detector is running (default is null, indicating both), and failed which defines whether the anomaly detector has failed (default is null, indicating both). The tool returns 2 values: a list of anomaly detectors (each containing the detector id, detector name, detector type indicating multi-entity or single-entity (where multi-entity also means high-cardinality), detector description, name of the configured index, last update time in epoch milliseconds), and the total number of anomaly detectors.";

@Setter
@Getter
Expand All @@ -70,9 +72,9 @@ public class SearchAnomalyDetectorsTool implements Tool {
@Setter
private Parser<?, ?> outputParser;

public SearchAnomalyDetectorsTool(Client client) {
public SearchAnomalyDetectorsTool(Client client, NamedWriteableRegistry namedWriteableRegistry) {
this.client = client;
this.adClient = new AnomalyDetectionNodeClient(client);
this.adClient = new AnomalyDetectionNodeClient(client, namedWriteableRegistry);

// probably keep this overridden output parser. need to ensure the output matches what's expected
outputParser = new Parser<>() {
Expand Down Expand Up @@ -105,7 +107,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
final int size = parameters.containsKey("size") ? Integer.parseInt(parameters.get("size")) : 20;
final int startIndex = parameters.containsKey("startIndex") ? Integer.parseInt(parameters.get("startIndex")) : 0;
final Boolean running = parameters.containsKey("running") ? Boolean.parseBoolean(parameters.get("running")) : null;
final Boolean disabled = parameters.containsKey("disabled") ? Boolean.parseBoolean(parameters.get("disabled")) : null;
final Boolean failed = parameters.containsKey("failed") ? Boolean.parseBoolean(parameters.get("failed")) : null;

List<QueryBuilder> mustList = new ArrayList<QueryBuilder>();
Expand Down Expand Up @@ -139,10 +140,16 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
ActionListener<SearchResponse> searchDetectorListener = ActionListener.<SearchResponse>wrap(response -> {
StringBuilder sb = new StringBuilder();
List<SearchHit> hits = Arrays.asList(response.getHits().getHits());
Map<String, SearchHit> hitsAsMap = hits.stream().collect(Collectors.toMap(SearchHit::getId, hit -> hit));
Map<String, SearchHit> hitsAsMap = new HashMap<>();
// We persist the hits map using detector name as the key. Note this is required to be unique from the AD plugin.
// We cannot use detector ID, because the detector in the response from the profile transport action does not include this,
// making it difficult to map potential hits that should be removed later on based on the profile response's detector state.
for (SearchHit hit : hits) {
hitsAsMap.put((String) hit.getSourceAsMap().get("name"), hit);
}

// If we need to filter by detector state, make subsequent profile API calls to each detector
if (running != null || disabled != null || failed != null) {
if (running != null || failed != null) {
List<CompletableFuture<GetAnomalyDetectorResponse>> profileFutures = new ArrayList<>();
for (SearchHit hit : hits) {
CompletableFuture<GetAnomalyDetectorResponse> profileFuture = new CompletableFuture<GetAnomalyDetectorResponse>()
Expand Down Expand Up @@ -183,7 +190,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)

for (GetAnomalyDetectorResponse profileResponse : profileResponses) {
if (profileResponse != null && profileResponse.getDetector() != null) {
String detectorId = profileResponse.getDetector().getId();
String responseDetectorName = profileResponse.getDetector().getName();

// We follow the existing logic as the frontend to determine overall detector state
// https://github.com/opensearch-project/anomaly-detection-dashboards-plugin/blob/main/server/routes/utils/adHelpers.ts#L437
Expand All @@ -192,9 +199,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)

if (realtimeTask != null) {
String taskState = realtimeTask.getState();
if (taskState.equalsIgnoreCase("CREATED")) {
detectorState = DetectorStateString.Initializing.name();
} else if (taskState.equalsIgnoreCase("RUNNING")) {
if (taskState.equalsIgnoreCase("CREATED") || taskState.equalsIgnoreCase("RUNNING")) {
detectorState = DetectorStateString.Running.name();
} else if (taskState.equalsIgnoreCase("INIT_FAILURE")
|| taskState.equalsIgnoreCase("UNEXPECTED_FAILURE")
Expand All @@ -203,12 +208,21 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
}
}

if ((Boolean.FALSE.equals(running) && detectorState.equals(DetectorStateString.Running.name()))
|| (Boolean.FALSE.equals(disabled) && detectorState.equals(DetectorStateString.Disabled.name()))
|| (Boolean.FALSE.equals(failed) && detectorState.equals(DetectorStateString.Failed.name()))) {
hitsAsMap.remove(detectorId);
boolean includeRunning = running != null && running == true;
boolean includeFailed = failed != null && failed == true;
boolean isValid = true;

if (detectorState.equals(DetectorStateString.Running.name())) {
isValid = (running == null || running == true) && !(includeFailed && running == null);
} else if (detectorState.equals(DetectorStateString.Failed.name())) {
isValid = (failed == null || failed == true) && !(includeRunning && failed == null);
} else if (detectorState.equals(DetectorStateString.Disabled.name())) {
isValid = (running == null || running == false) && !(includeFailed && running == null);
}

if (!isValid) {
hitsAsMap.remove(responseDetectorName);
}
}
}
}
Expand Down Expand Up @@ -262,6 +276,8 @@ private <T> void processHits(Map<String, SearchHit> hitsAsMap, ActionListener<T>
public static class Factory implements Tool.Factory<SearchAnomalyDetectorsTool> {
private Client client;

private NamedWriteableRegistry namedWriteableRegistry;

private AnomalyDetectionNodeClient adClient;

private static Factory INSTANCE;
Expand All @@ -286,14 +302,15 @@ public static Factory getInstance() {
* Initialize this factory
* @param client The OpenSearch client
*/
public void init(Client client) {
public void init(Client client, NamedWriteableRegistry namedWriteableRegistry) {
this.client = client;
this.adClient = new AnomalyDetectionNodeClient(client);
this.namedWriteableRegistry = namedWriteableRegistry;
this.adClient = new AnomalyDetectionNodeClient(client, namedWriteableRegistry);
}

@Override
public SearchAnomalyDetectorsTool create(Map<String, Object> map) {
return new SearchAnomalyDetectorsTool(client);
return new SearchAnomalyDetectorsTool(client, namedWriteableRegistry);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.agent.tools.utils.ToolConstants;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
Expand Down Expand Up @@ -61,9 +62,9 @@ public class SearchAnomalyResultsTool implements Tool {
@Setter
private Parser<?, ?> outputParser;

public SearchAnomalyResultsTool(Client client) {
public SearchAnomalyResultsTool(Client client, NamedWriteableRegistry namedWriteableRegistry) {
this.client = client;
this.adClient = new AnomalyDetectionNodeClient(client);
this.adClient = new AnomalyDetectionNodeClient(client, namedWriteableRegistry);

// probably keep this overridden output parser. need to ensure the output matches what's expected
outputParser = new Parser<>() {
Expand Down Expand Up @@ -190,6 +191,8 @@ private <T> void processHits(SearchHits searchHits, ActionListener<T> listener)
public static class Factory implements Tool.Factory<SearchAnomalyResultsTool> {
private Client client;

private NamedWriteableRegistry namedWriteableRegistry;

private AnomalyDetectionNodeClient adClient;

private static Factory INSTANCE;
Expand All @@ -214,14 +217,15 @@ public static Factory getInstance() {
* Initialize this factory
* @param client The OpenSearch client
*/
public void init(Client client) {
public void init(Client client, NamedWriteableRegistry namedWriteableRegistry) {
this.client = client;
this.adClient = new AnomalyDetectionNodeClient(client);
this.namedWriteableRegistry = namedWriteableRegistry;
this.adClient = new AnomalyDetectionNodeClient(client, namedWriteableRegistry);
}

@Override
public SearchAnomalyResultsTool create(Map<String, Object> map) {
return new SearchAnomalyResultsTool(client);
return new SearchAnomalyResultsTool(client, namedWriteableRegistry);
}

@Override
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/org/opensearch/agent/TestHelpers.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ public static SearchResponse generateSearchResponse(SearchHit[] hits) {
);
}

public static GetAnomalyDetectorResponse generateGetAnomalyDetectorResponses(String[] detectorIds, String[] detectorStates) {
public static GetAnomalyDetectorResponse generateGetAnomalyDetectorResponses(String[] detectorNames, String[] detectorStates) {
AnomalyDetector detector = Mockito.mock(AnomalyDetector.class);
// For each subsequent call to getId(), return the next detectorId in the array
when(detector.getId()).thenReturn(detectorIds[0], Arrays.copyOfRange(detectorIds, 1, detectorIds.length));
when(detector.getName()).thenReturn(detectorNames[0], Arrays.copyOfRange(detectorNames, 1, detectorNames.length));
ADTask realtimeAdTask = Mockito.mock(ADTask.class);
// For each subsequent call to getState(), return the next detectorState in the array
when(realtimeAdTask.getState()).thenReturn(detectorStates[0], Arrays.copyOfRange(detectorStates, 1, detectorStates.length));
Expand Down
Loading

0 comments on commit 44dc232

Please sign in to comment.