Skip to content

Commit

Permalink
add CreateAlertTool (#349)
Browse files Browse the repository at this point in the history
* add CreateAlertTool

Signed-off-by: Heng Qian <[email protected]>

* spotlessApply and address comments

Signed-off-by: Heng Qian <[email protected]>

* add IT for CreateAlertTool

Signed-off-by: Heng Qian <[email protected]>

* address comments

Signed-off-by: Heng Qian <[email protected]>

* fix after merging main

Signed-off-by: Heng Qian <[email protected]>

* fix forbidden API

Signed-off-by: Heng Qian <[email protected]>

* fix IT failure

Signed-off-by: Heng Qian <[email protected]>

* run spotlessCheck

Signed-off-by: Heng Qian <[email protected]>

* Address comments

Signed-off-by: Heng Qian <[email protected]>

* Address comments of changing getIndex to use ActionListener

Signed-off-by: Heng Qian <[email protected]>

* fix IT

Signed-off-by: Heng Qian <[email protected]>

* make prompt dict static

Signed-off-by: Heng Qian <[email protected]>

---------

Signed-off-by: Heng Qian <[email protected]>
  • Loading branch information
qianheng-aws authored Aug 1, 2024
1 parent 4f30964 commit 798065f
Show file tree
Hide file tree
Showing 9 changed files with 901 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.function.Supplier;

import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.CreateAlertTool;
import org.opensearch.agent.tools.CreateAnomalyDetectorTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
Expand Down Expand Up @@ -74,6 +75,7 @@ public Collection<Object> createComponents(
SearchAnomalyDetectorsTool.Factory.getInstance().init(client, namedWriteableRegistry);
SearchAnomalyResultsTool.Factory.getInstance().init(client, namedWriteableRegistry);
SearchMonitorsTool.Factory.getInstance().init(client);
CreateAlertTool.Factory.getInstance().init(client);
CreateAnomalyDetectorTool.Factory.getInstance().init(client);
return Collections.emptyList();
}
Expand All @@ -90,6 +92,7 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
SearchAnomalyDetectorsTool.Factory.getInstance(),
SearchAnomalyResultsTool.Factory.getInstance(),
SearchMonitorsTool.Factory.getInstance(),
CreateAlertTool.Factory.getInstance(),
CreateAnomalyDetectorTool.Factory.getInstance()
);
}
Expand Down
300 changes: 300 additions & 0 deletions src/main/java/org/opensearch/agent/tools/CreateAlertTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.indices.get.GetIndexRequest;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.agent.tools.utils.ToolConstants.ModelType;
import org.opensearch.agent.tools.utils.ToolHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.logging.LoggerMessageFormat;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;

import com.google.gson.reflect.TypeToken;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Log4j2
@ToolAnnotation(CreateAlertTool.TYPE)
public class CreateAlertTool implements Tool {
public static final String TYPE = "CreateAlertTool";

private static final String DEFAULT_DESCRIPTION =
"This is a tool that helps to create an alert(i.e. monitor with triggers), some parameters should be parsed based on user's question and context. The parameters should include: \n"
+ "1. indices: The input indices of the monitor, should be a list of string in json format.\n";

@Setter
@Getter
private String name = TYPE;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;

private final Client client;
private final String modelId;
private final String TOOL_PROMPT_TEMPLATE;

private static final String MODEL_ID = "model_id";
private static final String PROMPT_FILE_PATH = "CreateAlertDefaultPrompt.json";
private static final String DEFAULT_QUESTION = "Create an alert as your recommendation based on the context";
private static final Map<String, String> promptDict = ToolHelper.loadDefaultPromptDictFromFile(CreateAlertTool.class, PROMPT_FILE_PATH);

public CreateAlertTool(Client client, String modelId, String modelType) {
this.client = client;
this.modelId = modelId;
if (!promptDict.containsKey(modelType)) {
throw new IllegalArgumentException(
LoggerMessageFormat
.format(
null,
"Failed to find the right prompt for modelType: {}, this tool supports prompts for these models: [{}]",
modelType,
String.join(",", promptDict.keySet())
)
);
}
TOOL_PROMPT_TEMPLATE = promptDict.get(modelType);
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public boolean validate(Map<String, String> parameters) {
return true;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
Map<String, String> tmpParams = new HashMap<>(parameters);
if (!tmpParams.containsKey("indices") || Strings.isEmpty(tmpParams.get("indices"))) {
throw new IllegalArgumentException(
"No indices in the input parameter. Ask user to "
+ "provide index as your final answer directly without using any other tools"
);
}
String rawIndex = tmpParams.getOrDefault("indices", "");
Boolean isLocal = Boolean.parseBoolean(tmpParams.getOrDefault("local", "true"));
final GetIndexRequest getIndexRequest = constructIndexRequest(rawIndex, isLocal);
client.admin().indices().getIndex(getIndexRequest, ActionListener.wrap(response -> {
if (response.indices().length == 0) {
throw new IllegalArgumentException(
LoggerMessageFormat
.format(
null,
"Cannot find provided indices {}. Ask "
+ "user to check the provided indices as your final answer without using any other "
+ "tools",
rawIndex
)
);
}
StringBuilder sb = new StringBuilder();
for (String index : response.indices()) {
sb.append("index: ").append(index).append("\n\n");

MappingMetadata mapping = response.mappings().get(index);
if (mapping != null) {
sb.append("mappings:\n");
for (Entry<String, Object> entry : mapping.sourceAsMap().entrySet()) {
sb.append(entry.getKey()).append("=").append(entry.getValue()).append('\n');
}
sb.append("\n\n");
}
}
String mappingInfo = sb.toString();
ActionRequest request = constructMLPredictRequest(tmpParams, mappingInfo);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
Map<String, ?> dataMap = Optional
.ofNullable(modelTensorOutput.getMlModelOutputs())
.flatMap(outputs -> outputs.stream().findFirst())
.flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst())
.map(ModelTensor::getDataAsMap)
.orElse(null);
if (dataMap == null) {
throw new IllegalArgumentException("No dataMap returned from LLM.");
}
String alertInfo = "";
if (dataMap.containsKey("response")) {
alertInfo = (String) dataMap.get("response");
Pattern jsonPattern = Pattern.compile("```json(.*?)```", Pattern.DOTALL);
Matcher jsonBlockMatcher = jsonPattern.matcher(alertInfo);
if (jsonBlockMatcher.find()) {
alertInfo = jsonBlockMatcher.group(1);
alertInfo = alertInfo.replace("\\\"", "\"");
}
} else {
// LLM sometimes returns the tensor results as a json object directly instead of
// string response, and the json object is stored as a map.
alertInfo = StringUtils.toJson(dataMap);
}
if (!isJson(alertInfo)) {
throw new IllegalArgumentException(
LoggerMessageFormat.format(null, "The response from LLM is not a json: [{}]", alertInfo)
);
}
listener.onResponse((T) alertInfo);
}, e -> {
log.error("Failed to run model " + modelId, e);
listener.onFailure(e);
}));
}, e -> {
log.error("failed to get index mapping: " + e);
if (e.toString().contains("IndexNotFoundException")) {
listener
.onFailure(
new IllegalArgumentException(
LoggerMessageFormat
.format(
null,
"Cannot find provided indices {}. Ask "
+ "user to check the provided indices as your final answer without using any other "
+ "tools",
rawIndex
)
)
);
} else {
listener.onFailure(e);
}
}));
}

private ActionRequest constructMLPredictRequest(Map<String, String> tmpParams, String mappingInfo) {
tmpParams.put("mapping_info", mappingInfo);
tmpParams.putIfAbsent("indices", "");
tmpParams.putIfAbsent("chat_history", "");
tmpParams.putIfAbsent("question", DEFAULT_QUESTION); // In case no question is provided, use a default question.
StringSubstitutor substitute = new StringSubstitutor(tmpParams, "${parameters.", "}");
String finalToolPrompt = substitute.replace(TOOL_PROMPT_TEMPLATE);
tmpParams.put("prompt", finalToolPrompt);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParams).build();
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()
);
return request;
}

private static GetIndexRequest constructIndexRequest(String rawIndex, Boolean isLocal) {
List<String> indexList;
try {
indexList = StringUtils.gson.fromJson(rawIndex, new TypeToken<List<String>>() {
}.getType());
} catch (Exception e) {
// LLM sometimes returns the indices as a string but not json format, although we require that in the tool description.
indexList = Arrays.asList(rawIndex.split("\\."));
}
if (indexList.isEmpty()) {
throw new IllegalArgumentException(
"The input indices is empty. Ask user to " + "provide index as your final answer directly without using any other tools"
);
} else if (indexList.stream().anyMatch(index -> index.startsWith("."))) {
throw new IllegalArgumentException(
LoggerMessageFormat
.format(
null,
"The provided indices [{}] contains system index, which is not allowed. Ask user to "
+ "check the provided indices as your final answer without using any other.",
rawIndex
)
);
}
final String[] indices = indexList.toArray(Strings.EMPTY_ARRAY);
final GetIndexRequest getIndexRequest = new GetIndexRequest()
.indices(indices)
.indicesOptions(IndicesOptions.strictExpand())
.local(isLocal)
.clusterManagerNodeTimeout(DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT);
return getIndexRequest;
}

public static class Factory implements Tool.Factory<CreateAlertTool> {

private Client client;

private static Factory INSTANCE;

public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (CreateAlertTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

public void init(Client client) {
this.client = client;
}

@Override
public CreateAlertTool create(Map<String, Object> params) {
String modelId = (String) params.get(MODEL_ID);
if (Strings.isBlank(modelId)) {
throw new IllegalArgumentException("model_id cannot be null or blank.");
}
String modelType = (String) params.getOrDefault("model_type", ModelType.CLAUDE.toString());
return new CreateAlertTool(client, modelId, modelType);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}

@Override
public String getDefaultType() {
return TYPE;
}

@Override
public String getDefaultVersion() {
return null;
}
}
}
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.agent.tools.utils;

import java.util.Locale;

public class ToolConstants {
// Detector state is not cleanly defined on the backend plugin. So, we persist a standard
// set of states here for users to interface with when fetching and filtering detectors.
Expand All @@ -17,6 +19,15 @@ public static enum DetectorStateString {
Initializing
}

public enum ModelType {
CLAUDE,
OPENAI;

public static ModelType from(String value) {
return valueOf(value.toUpperCase(Locale.ROOT));
}
}

// System indices constants are not cleanly exposed from the AD & Alerting plugins, so we persist our
// own constants here.
public static final String AD_RESULTS_INDEX_PATTERN = ".opendistro-anomaly-results*";
Expand Down
27 changes: 27 additions & 0 deletions src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,36 @@

package org.opensearch.agent.tools.utils;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.ml.common.utils.StringUtils;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class ToolHelper {
/**
* Load prompt from the resource file of the invoking class
* @param source class which calls this function
* @param fileName the resource file name of prompt
* @return the LLM request prompt template.
*/
public static Map<String, String> loadDefaultPromptDictFromFile(Class<?> source, String fileName) {
try (InputStream searchResponseIns = source.getResourceAsStream(fileName)) {
if (searchResponseIns != null) {
String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
return StringUtils.gson.fromJson(defaultPromptContent, Map.class);
}
} catch (IOException e) {
log.error("Failed to load default prompt dict from file: {}", fileName, e);
}
return new HashMap<>();
}

/**
* Flatten all the fields in the mappings, insert the field to fieldType mapping to a map
* @param mappingSource the mappings of an index
Expand Down
Loading

0 comments on commit 798065f

Please sign in to comment.