diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 41c8f5c3..28987f0c 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -10,6 +10,8 @@ import java.util.List; import java.util.function.Supplier; +import org.opensearch.agent.common.SkillSettings; +import org.opensearch.agent.tools.CreateAlertTool; import org.opensearch.agent.tools.CreateAnomalyDetectorTool; import org.opensearch.agent.tools.LogPatternTool; import org.opensearch.agent.tools.NeuralSparseSearchTool; @@ -69,6 +71,7 @@ public Collection createComponents( SearchAnomalyDetectorsTool.Factory.getInstance().init(client, namedWriteableRegistry); SearchAnomalyResultsTool.Factory.getInstance().init(client, namedWriteableRegistry); SearchMonitorsTool.Factory.getInstance().init(client); + CreateAlertTool.Factory.getInstance().init(client); CreateAnomalyDetectorTool.Factory.getInstance().init(client); LogPatternTool.Factory.getInstance().init(client, xContentRegistry); return Collections.emptyList(); @@ -86,6 +89,7 @@ public List> getToolFactories() { SearchAnomalyDetectorsTool.Factory.getInstance(), SearchAnomalyResultsTool.Factory.getInstance(), SearchMonitorsTool.Factory.getInstance(), + CreateAlertTool.Factory.getInstance(), CreateAnomalyDetectorTool.Factory.getInstance(), LogPatternTool.Factory.getInstance() ); diff --git a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java new file mode 100644 index 00000000..1d6b1c36 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java @@ -0,0 +1,300 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; +import static org.opensearch.ml.common.utils.StringUtils.isJson; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.apache.commons.text.StringSubstitutor; +import org.apache.logging.log4j.util.Strings; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.admin.indices.get.GetIndexRequest; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.agent.tools.utils.ToolConstants.ModelType; +import org.opensearch.agent.tools.utils.ToolHelper; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.logging.LoggerMessageFormat; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; + +import com.google.gson.reflect.TypeToken; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ToolAnnotation(CreateAlertTool.TYPE) +public class CreateAlertTool implements Tool { + public static final String TYPE = "CreateAlertTool"; + + private static final String DEFAULT_DESCRIPTION = + "This is a tool that helps to create an alert(i.e. monitor with triggers), some parameters should be parsed based on user's question and context. The parameters should include: \n" + + "1. indices: The input indices of the monitor, should be a list of string in json format.\n"; + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + private final Client client; + private final String modelId; + private final String TOOL_PROMPT_TEMPLATE; + + private static final String MODEL_ID = "model_id"; + private static final String PROMPT_FILE_PATH = "CreateAlertDefaultPrompt.json"; + private static final String DEFAULT_QUESTION = "Create an alert as your recommendation based on the context"; + private static final Map promptDict = ToolHelper.loadDefaultPromptDictFromFile(CreateAlertTool.class, PROMPT_FILE_PATH); + + public CreateAlertTool(Client client, String modelId, String modelType) { + this.client = client; + this.modelId = modelId; + if (!promptDict.containsKey(modelType)) { + throw new IllegalArgumentException( + LoggerMessageFormat + .format( + null, + "Failed to find the right prompt for modelType: {}, this tool supports prompts for these models: [{}]", + modelType, + String.join(",", promptDict.keySet()) + ) + ); + } + TOOL_PROMPT_TEMPLATE = promptDict.get(modelType); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getVersion() { + return null; + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + @Override + public void run(Map parameters, ActionListener listener) { + Map tmpParams = new HashMap<>(parameters); + if (!tmpParams.containsKey("indices") || Strings.isEmpty(tmpParams.get("indices"))) { + throw new IllegalArgumentException( + "No indices in the input parameter. Ask user to " + + "provide index as your final answer directly without using any other tools" + ); + } + String rawIndex = tmpParams.getOrDefault("indices", ""); + Boolean isLocal = Boolean.parseBoolean(tmpParams.getOrDefault("local", "true")); + final GetIndexRequest getIndexRequest = constructIndexRequest(rawIndex, isLocal); + client.admin().indices().getIndex(getIndexRequest, ActionListener.wrap(response -> { + if (response.indices().length == 0) { + throw new IllegalArgumentException( + LoggerMessageFormat + .format( + null, + "Cannot find provided indices {}. Ask " + + "user to check the provided indices as your final answer without using any other " + + "tools", + rawIndex + ) + ); + } + StringBuilder sb = new StringBuilder(); + for (String index : response.indices()) { + sb.append("index: ").append(index).append("\n\n"); + + MappingMetadata mapping = response.mappings().get(index); + if (mapping != null) { + sb.append("mappings:\n"); + for (Entry entry : mapping.sourceAsMap().entrySet()) { + sb.append(entry.getKey()).append("=").append(entry.getValue()).append('\n'); + } + sb.append("\n\n"); + } + } + String mappingInfo = sb.toString(); + ActionRequest request = constructMLPredictRequest(tmpParams, mappingInfo); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput(); + Map dataMap = Optional + .ofNullable(modelTensorOutput.getMlModelOutputs()) + .flatMap(outputs -> outputs.stream().findFirst()) + .flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst()) + .map(ModelTensor::getDataAsMap) + .orElse(null); + if (dataMap == null) { + throw new IllegalArgumentException("No dataMap returned from LLM."); + } + String alertInfo = ""; + if (dataMap.containsKey("response")) { + alertInfo = (String) dataMap.get("response"); + Pattern jsonPattern = Pattern.compile("```json(.*?)```", Pattern.DOTALL); + Matcher jsonBlockMatcher = jsonPattern.matcher(alertInfo); + if (jsonBlockMatcher.find()) { + alertInfo = jsonBlockMatcher.group(1); + alertInfo = alertInfo.replace("\\\"", "\""); + } + } else { + // LLM sometimes returns the tensor results as a json object directly instead of + // string response, and the json object is stored as a map. + alertInfo = StringUtils.toJson(dataMap); + } + if (!isJson(alertInfo)) { + throw new IllegalArgumentException( + LoggerMessageFormat.format(null, "The response from LLM is not a json: [{}]", alertInfo) + ); + } + listener.onResponse((T) alertInfo); + }, e -> { + log.error("Failed to run model " + modelId, e); + listener.onFailure(e); + })); + }, e -> { + log.error("failed to get index mapping: " + e); + if (e.toString().contains("IndexNotFoundException")) { + listener + .onFailure( + new IllegalArgumentException( + LoggerMessageFormat + .format( + null, + "Cannot find provided indices {}. Ask " + + "user to check the provided indices as your final answer without using any other " + + "tools", + rawIndex + ) + ) + ); + } else { + listener.onFailure(e); + } + })); + } + + private ActionRequest constructMLPredictRequest(Map tmpParams, String mappingInfo) { + tmpParams.put("mapping_info", mappingInfo); + tmpParams.putIfAbsent("indices", ""); + tmpParams.putIfAbsent("chat_history", ""); + tmpParams.putIfAbsent("question", DEFAULT_QUESTION); // In case no question is provided, use a default question. + StringSubstitutor substitute = new StringSubstitutor(tmpParams, "${parameters.", "}"); + String finalToolPrompt = substitute.replace(TOOL_PROMPT_TEMPLATE); + tmpParams.put("prompt", finalToolPrompt); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParams).build(); + ActionRequest request = new MLPredictionTaskRequest( + modelId, + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build() + ); + return request; + } + + private static GetIndexRequest constructIndexRequest(String rawIndex, Boolean isLocal) { + List indexList; + try { + indexList = StringUtils.gson.fromJson(rawIndex, new TypeToken>() { + }.getType()); + } catch (Exception e) { + // LLM sometimes returns the indices as a string but not json format, although we require that in the tool description. + indexList = Arrays.asList(rawIndex.split("\\.")); + } + if (indexList.isEmpty()) { + throw new IllegalArgumentException( + "The input indices is empty. Ask user to " + "provide index as your final answer directly without using any other tools" + ); + } else if (indexList.stream().anyMatch(index -> index.startsWith("."))) { + throw new IllegalArgumentException( + LoggerMessageFormat + .format( + null, + "The provided indices [{}] contains system index, which is not allowed. Ask user to " + + "check the provided indices as your final answer without using any other.", + rawIndex + ) + ); + } + final String[] indices = indexList.toArray(Strings.EMPTY_ARRAY); + final GetIndexRequest getIndexRequest = new GetIndexRequest() + .indices(indices) + .indicesOptions(IndicesOptions.strictExpand()) + .local(isLocal) + .clusterManagerNodeTimeout(DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT); + return getIndexRequest; + } + + public static class Factory implements Tool.Factory { + + private Client client; + + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (CreateAlertTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + @Override + public CreateAlertTool create(Map params) { + String modelId = (String) params.get(MODEL_ID); + if (Strings.isBlank(modelId)) { + throw new IllegalArgumentException("model_id cannot be null or blank."); + } + String modelType = (String) params.getOrDefault("model_type", ModelType.CLAUDE.toString()); + return new CreateAlertTool(client, modelId, modelType); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java b/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java index 2a90ec7e..b5433a0e 100644 --- a/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java +++ b/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java @@ -5,6 +5,8 @@ package org.opensearch.agent.tools.utils; +import java.util.Locale; + public class ToolConstants { // Detector state is not cleanly defined on the backend plugin. So, we persist a standard // set of states here for users to interface with when fetching and filtering detectors. @@ -17,6 +19,15 @@ public static enum DetectorStateString { Initializing } + public enum ModelType { + CLAUDE, + OPENAI; + + public static ModelType from(String value) { + return valueOf(value.toUpperCase(Locale.ROOT)); + } + } + // System indices constants are not cleanly exposed from the AD & Alerting plugins, so we persist our // own constants here. public static final String AD_RESULTS_INDEX_PATTERN = ".opendistro-anomaly-results*"; diff --git a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java index d7f6c3f5..b60f46c9 100644 --- a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java +++ b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java @@ -5,9 +5,36 @@ package org.opensearch.agent.tools.utils; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; import java.util.Map; +import org.opensearch.ml.common.utils.StringUtils; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 public class ToolHelper { + /** + * Load prompt from the resource file of the invoking class + * @param source class which calls this function + * @param fileName the resource file name of prompt + * @return the LLM request prompt template. + */ + public static Map loadDefaultPromptDictFromFile(Class source, String fileName) { + try (InputStream searchResponseIns = source.getResourceAsStream(fileName)) { + if (searchResponseIns != null) { + String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + return StringUtils.gson.fromJson(defaultPromptContent, Map.class); + } + } catch (IOException e) { + log.error("Failed to load default prompt dict from file: {}", fileName, e); + } + return new HashMap<>(); + } + /** * Flatten all the fields in the mappings, insert the field to fieldType mapping to a map * @param mappingSource the mappings of an index diff --git a/src/main/resources/org/opensearch/agent/tools/CreateAlertDefaultPrompt.json b/src/main/resources/org/opensearch/agent/tools/CreateAlertDefaultPrompt.json new file mode 100644 index 00000000..cbd2fe56 --- /dev/null +++ b/src/main/resources/org/opensearch/agent/tools/CreateAlertDefaultPrompt.json @@ -0,0 +1,4 @@ +{ + "CLAUDE": "\n\nHuman:Generate the monitor definition for an alert as the following json format without any other information.\nThe value of `${field_name}` should be a property in the mapping info and `${field_type} should be its type`\nYou could recommend some recommended fields or values for the definition if user does not provide them.\nNo need explanation in the response, it's very important that don't miss ```json.\n\nHuman:RESPONSE FORMAT INSTRUCTIONS\n----------------------------\n```json\n{\n \"name\": \"\" //monitor name\n \"search\": {\n \"indices\": ${parameters.indices} //no need to change.\n \"timeField\": \"${field_name}\", //this field should be date type\n \"bucketValue\": 1, //A numeric value defining the time range for the last, default is 1.\n \"bucketUnitOfTime\": \"m\", //The time unit for the bucketValue, options include 'm' (minutes), 'h' (hours), or 'd' (days), with a default of 'h'\n \"filters\": [ // A list of filters to filter logs to meet user's alert definition.\n {\n \"fieldName\": [ //The length should be 1.\n {\n \"label\": \"${field_name}\",\n \"type\": \"${field_type}\" //options are 'number', 'text', 'keyword', 'boolean'\n }\n ],\n \"fieldValue\": 10,\n \"operator\": \"is\" //options are 'is', 'is_not', 'contains', 'does_not_contains', 'starts_with', 'ends_with', 'is_greater', 'is_greater_equal, 'is_less', 'is_less_equal.\n },\n ],\n \"aggregations\": [\n {\n \"aggregationType\": \"count\", //options are ‘count’, ‘max’, ‘min’, ‘avg’, or ‘sum’, with a default of ‘count’.\n \"fieldName\": \"${field_name}\"\n }\n ]\n },\n \"triggers\": [{ // The triggers for the alert, it triggers when the above aggregation result satisfies the threshold.\n \"name\": 'Trigger' //The name of the trigger. You could generate this name based on this trigger definition.\n \"severity\": 1, //The severity level of the trigger, options are 1, 2 and 3.\n \"thresholdValue\": 0, //A numeric value defining the threshold.\n \"thresholdEnum\": \"ABOVE\" //options are 'ABOVE', 'BELOW', or 'EXACTLY'.\n }],\n}\n```\nHuman: Examples\n--------------------\nquestion: create alert if the count of non 200 response happens over 30 times per hour.\nresponse: {\\n\"name\": \"Error Response Alert\",\\n\"search\": {\\n\"indices\": [\"opensearch_dashboards_sample_data_logs\"],\\n\"timeField\": \"timestamp\",\\n\"bucketValue\": 60,\\n\"bucketUnitOfTime\": \"m\",\\n\"filters\": [\\n{\\n\"fieldName\": [\\n{\\n\"label\": \"response\",\\n\"type\": \"text\"\\n}\\n],\\n\"fieldValue\": \"200\",\\n\"operator\": \"is_not\"\\n}\\n],\\n\"aggregations\": [\\n{\\n\"aggregationType\": \"count\",\\n\"fieldName\": \"bytes\"\\n}\\n]\\n},\\n\"triggers\": [{\\n\"name\": \"Error Response Count Above 30\", \\n\"value\": 30,\\n\"enum\": \"ABOVE\"\\n}]\\n}\nHuman:USER'S CONTEXT\n--------------------\nmapping_info of the target index: ${parameters.mapping_info}\nHuman:CHAT HISTORY\n--------------------\n${parameters.chat_history}\nHuman:USER'S INPUT\n--------------------\nHere is the user's input :\n${parameters.question}\n\nAssistant:", + "OPENAI": "Generate the monitor definition for an alert as the following json format without any other information.\nThe value of `${field_name}` should be a property in the mapping info and `${field_type} should be its type`\nYou could recommend some recommended fields or values for the definition if user does not provide them.\nNo need explanation in the response, it's very important that don't miss ```json.\n\nHuman:RESPONSE FORMAT INSTRUCTIONS\n----------------------------\n```json\n{\n \"name\": \"\" //monitor name\n \"search\": {\n \"indices\": ${parameters.indices} //no need to change.\n \"timeField\": \"${field_name}\",\n \"bucketValue\": 1, //A numeric value defining the time range for the last, default is 1.\n \"bucketUnitOfTime\": \"m\", //The time unit for the bucketValue, options include 'm' (minutes), 'h' (hours), or 'd' (days), with a default of 'h'\n \"filters\": [ // A list of filters to filter logs to meet user's alert definition.\n {\n \"fieldName\": [ //The length should be 1.\n {\n \"label\": \"${field_name}\",\n \"type\": \"${field_type}\" //options are 'number', 'text', 'keyword', 'boolean'\n }\n ],\n \"fieldValue\": 10,\n \"operator\": \"is\" //options are 'is', 'is_not', 'contains', 'does_not_contains', 'starts_with', 'ends_with', 'is_greater', 'is_greater_equal, 'is_less', 'is_less_equal.\n },\n ],\n \"aggregations\": [\n {\n \"aggregationType\": \"count\", //options are ‘count’, ‘max’, ‘min’, ‘avg’, or ‘sum’, with a default of ‘count’.\n \"fieldName\": \"${field_name}\"\n }\n ]\n },\n \"triggers\": [{ // The triggers for the alert, it triggers when the above aggregation result satisfies the threshold.\n \"name\": 'Trigger' //The name of the trigger. You could generate this name based on this trigger definition.\n \"severity\": 1, //The severity level of the trigger, options are 1, 2 and 3.\n \"thresholdValue\": 0, //A numeric value defining the threshold.\n \"thresholdEnum\": \"ABOVE\" //options are 'ABOVE', 'BELOW', or 'EXACTLY'.\n }],\n}\n```\nHuman: Examples\n--------------------\nquestion: create alert if the count of non 200 response happens over 30 times per hour.\nresponse: {\\n\"name\": \"Error Response Alert\",\\n\"search\": {\\n\"indices\": [\"opensearch_dashboards_sample_data_logs\"],\\n\"timeField\": \"timestamp\",\\n\"bucketValue\": 60,\\n\"bucketUnitOfTime\": \"m\",\\n\"filters\": [\\n{\\n\"fieldName\": [\\n{\\n\"label\": \"response\",\\n\"type\": \"text\"\\n}\\n],\\n\"fieldValue\": \"200\",\\n\"operator\": \"is_not\"\\n}\\n],\\n\"aggregations\": [\\n{\\n\"aggregationType\": \"count\",\\n\"fieldName\": \"bytes\"\\n}\\n]\\n},\\n\"triggers\": [{\\n\"name\": \"Error Response Count Above 30\", \\n\"value\": 30,\\n\"enum\": \"ABOVE\"\\n}]\\n}\nHuman:USER'S CONTEXT\n--------------------\nmapping_info of the target index: ${parameters.mapping_info}\nHuman:CHAT HISTORY\n--------------------\n${parameters.chat_history}\nHuman:USER'S INPUT\n--------------------\nHere is the user's input :\n${parameters.question}" +} diff --git a/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java b/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java new file mode 100644 index 00000000..f0b6a245 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java @@ -0,0 +1,382 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.get.GetIndexResponse; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; + +import com.google.common.collect.ImmutableMap; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CreateAlertToolTests { + private final Client client = mock(Client.class); + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private GetMappingsResponse getMappingsResponse; + @Mock + private MappingMetadata mappingMetadata; + private Map mockedMappings; + private Map indexMappings; + @Mock + private MLTaskResponse mlTaskResponse; + @Mock + private ModelTensorOutput modelTensorOutput; + @Mock + private ModelTensors modelTensors; + @Mock + private ActionFuture actionFuture; + @Mock + private GetIndexResponse getIndexResponse; + + private final String jsonResponse = "{\"name\":\"mocked_response\"}"; + private final String mockedIndexName = "mocked_index_name"; + private final String mockedIndices = String.format("[%s]", mockedIndexName); + private CreateAlertTool tool; + + @Before + public void setup() throws ExecutionException, InterruptedException { + MockitoAnnotations.openMocks(this); + createMappings(); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(getIndexResponse); + return null; + }).when(indicesAdminClient).getIndex(any(), any()); + + when(getIndexResponse.indices()).thenReturn(new String[] { mockedIndexName }); + when(getIndexResponse.mappings()).thenReturn(mockedMappings); + when(mappingMetadata.getSourceAsMap()).thenReturn(indexMappings); + + CreateAlertTool.Factory.getInstance().init(client); + tool = CreateAlertTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + assertEquals(CreateAlertTool.TYPE, tool.getName()); + } + + private void createMappings() { + indexMappings = new HashMap<>(); + indexMappings + .put( + "properties", + ImmutableMap + .of( + "field1", + ImmutableMap.of("type", "integer"), + "field2", + ImmutableMap.of("type", "float"), + "field3", + ImmutableMap.of("type", "date") + ) + ); + mockedMappings = new HashMap<>(); + mockedMappings.put(mockedIndexName, mappingMetadata); + } + + private void initMLTensors(String response) { + Map modelReturns = Collections.singletonMap("response", response); + initMLTensors(modelReturns); + } + + private void initMLTensorsWithoutResponse(String response) { + assert (isJson(response)); + Map modelReturns = gson.fromJson(response, Map.class); + initMLTensors(modelReturns); + } + + private void initMLTensors(Map modelReturns) { + ModelTensor modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + when(modelTensors.getMlModelTensors()).thenReturn(Collections.singletonList(modelTensor)); + when(modelTensorOutput.getMlModelOutputs()).thenReturn(Collections.singletonList(modelTensors)); + when(mlTaskResponse.getOutput()).thenReturn(modelTensorOutput); + + // call model + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + } + + @Test + public void testTool_WithoutModelId() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> CreateAlertTool.Factory.getInstance().create(Collections.emptyMap()) + ); + assertEquals("model_id cannot be null or blank.", exception.getMessage()); + } + + @Test + public void testTool_WithBlankModelId() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> CreateAlertTool.Factory.getInstance().create(ImmutableMap.of("model_id", " ")) + ); + assertEquals("model_id cannot be null or blank.", exception.getMessage()); + } + + @Test + public void testTool_WithNonSupportedModelType() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> CreateAlertTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType")) + ); + assertEquals( + "Failed to find the right prompt for modelType: non_supported_modelType, this tool supports prompts for these models: [CLAUDE,OPENAI]", + exception.getMessage() + ); + } + + @Test + public void testTool() { + // test json response + initMLTensors(jsonResponse); + tool + .run( + ImmutableMap.of("indices", mockedIndices, "question", "test_question"), + ActionListener + .wrap(response -> assertEquals(jsonResponse, response), e -> fail("Tool runs failed: " + e.getMessage())) + ); + + // test text response wrapping json + final String textResponseWithJson = String.format("RESPONSE_HEADER\n Tool output: ```json%s```, RESPONSE_FOOTER\n", jsonResponse); + initMLTensors(textResponseWithJson); + tool + .run( + ImmutableMap.of("indices", mockedIndices, "question", "test_question"), + ActionListener + .wrap(response -> assertEquals(jsonResponse, response), e -> fail("Tool runs failed: " + e.getMessage())) + ); + + // test tensor result without a string response but a json object directly. + initMLTensorsWithoutResponse(jsonResponse); + tool + .run( + ImmutableMap.of("indices", mockedIndices, "question", "test_question"), + ActionListener + .wrap(response -> assertEquals(jsonResponse, response), e -> fail("Tool runs failed: " + e.getMessage())) + ); + } + + @Test + public void testToolWithIndicesNotInJsonFormat() { + // test indices no in json format + initMLTensors(jsonResponse); + tool + .run( + ImmutableMap.of("indices", mockedIndexName, "question", "test_question"), + ActionListener + .wrap(response -> assertEquals(jsonResponse, response), e -> fail("Tool runs failed: " + e.getMessage())) + ); + + tool + .run( + ImmutableMap.of("indices", mockedIndexName + "," + mockedIndexName, "question", "test_question"), + ActionListener + .wrap(response -> assertEquals(jsonResponse, response), e -> fail("Tool runs failed: " + e.getMessage())) + ); + + } + + @Test + public void testToolWithNoJsonResponse() { + String noJsonResponse = "No json response"; + initMLTensors(noJsonResponse); + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> tool + .run( + ImmutableMap.of("indices", mockedIndices, "question", "test_question"), + ActionListener.wrap(response -> assertEquals(noJsonResponse, response), e -> { + throw new IllegalArgumentException(e.getMessage()); + }) + ) + ); + assertEquals(String.format("The response from LLM is not a json: [%s]", noJsonResponse), exception.getMessage()); + + final String textResponseWithJson = String.format("RESPONSE_HEADER\n Tool output: ```json%s```, RESPONSE_FOOTER\n", noJsonResponse); + initMLTensors(textResponseWithJson); + Exception exception2 = assertThrows( + IllegalArgumentException.class, + () -> tool + .run( + ImmutableMap.of("indices", mockedIndices, "question", "test_question"), + ActionListener.wrap(response -> assertEquals(noJsonResponse, response), e -> { + throw new IllegalArgumentException(e.getMessage()); + }) + ) + ); + assertEquals(String.format("The response from LLM is not a json: [%s]", noJsonResponse), exception2.getMessage()); + + } + + @Test + public void testToolWithPredictModelFailed() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Failed to predict")); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + Exception exception = assertThrows( + RuntimeException.class, + () -> tool + .run( + ImmutableMap.of("indices", mockedIndices, "question", "test_question"), + ActionListener.wrap(response -> assertEquals(jsonResponse, response), e -> { + throw new RuntimeException(e.getMessage()); + }) + ) + ); + assertEquals("Failed to predict", exception.getMessage()); + } + + @Test + public void testToolWithIllegalIndices() { + // no indices in input parameters + Exception exception = assertThrows( + RuntimeException.class, + () -> tool + .run( + ImmutableMap.of("question", "test_question"), + ActionListener.wrap(response -> assertEquals(jsonResponse, response), e -> { + throw new RuntimeException(e.getMessage()); + }) + ) + ); + assertEquals( + "No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools", + exception.getMessage() + ); + + // empty string as indices + exception = assertThrows( + RuntimeException.class, + () -> tool + .run( + ImmutableMap.of("indices", "", "question", "test_question"), + ActionListener.wrap(response -> assertEquals(jsonResponse, response), e -> { + throw new RuntimeException(e.getMessage()); + }) + ) + ); + assertEquals( + "No indices in the input parameter. Ask user to provide index as your final answer directly without using any other tools", + exception.getMessage() + ); + + // indices is an empty list + exception = assertThrows( + RuntimeException.class, + () -> tool + .run( + ImmutableMap.of("indices", "[]", "question", "test_question"), + ActionListener.wrap(response -> assertEquals(jsonResponse, response), e -> { + throw new RuntimeException(e.getMessage()); + }) + ) + ); + assertEquals( + "The input indices is empty. Ask user to provide index as your final answer directly without using any other tools", + exception.getMessage() + ); + + // indices contain system index + exception = assertThrows( + RuntimeException.class, + () -> tool + .run( + ImmutableMap.of("indices", "[.kibana]", "question", "test_question"), + ActionListener.wrap(response -> assertEquals(jsonResponse, response), e -> { + throw new RuntimeException(e.getMessage()); + }) + ) + ); + assertEquals( + "The provided indices [[.kibana]] contains system index, which is not allowed. Ask user to check the provided indices as your final answer without using any other.", + exception.getMessage() + ); + + // Cannot find provided indices in opensearch + when(getIndexResponse.indices()).thenReturn(new String[] {}); + exception = assertThrows( + RuntimeException.class, + () -> tool + .run( + ImmutableMap.of("indices", "[non_existed_index]", "question", "test_question"), + ActionListener.wrap(response -> assertEquals(jsonResponse, response), e -> { + throw new RuntimeException(e.getMessage()); + }) + ) + ); + assertEquals( + "Cannot find provided indices [non_existed_index]. Ask user to check the provided indices as your final answer without using any other tools", + exception.getMessage() + ); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new IndexNotFoundException("no such index")); + return null; + }).when(indicesAdminClient).getIndex(any(), any()); + + exception = assertThrows( + RuntimeException.class, + () -> tool + .run( + ImmutableMap.of("indices", "[non_existed_index]", "question", "test_question"), + ActionListener.wrap(response -> assertEquals(jsonResponse, response), e -> { + throw new RuntimeException(e.getMessage()); + }) + ) + ); + assertEquals( + "Cannot find provided indices [non_existed_index]. Ask user to check the provided indices as your final answer without using any other tools", + exception.getMessage() + ); + } +} diff --git a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java index 658a3fc7..0525dea5 100644 --- a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java +++ b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java @@ -6,6 +6,8 @@ package org.opensearch.integTest; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.Collections; import java.util.List; import java.util.Locale; @@ -388,4 +390,12 @@ public static Response makeRequest( } return client.performRequest(request); } + + @SneakyThrows + protected String registerAgent(String modelId, String requestBodyResourceFile) { + String registerAgentRequestBody = Files + .readString(Path.of(this.getClass().getClassLoader().getResource(requestBodyResourceFile).toURI())); + registerAgentRequestBody = registerAgentRequestBody.replace("", modelId); + return createAgent(registerAgentRequestBody); + } } diff --git a/src/test/java/org/opensearch/integTest/CreateAlertToolIT.java b/src/test/java/org/opensearch/integTest/CreateAlertToolIT.java new file mode 100644 index 00000000..be93606c --- /dev/null +++ b/src/test/java/org/opensearch/integTest/CreateAlertToolIT.java @@ -0,0 +1,152 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.containsString; + +import java.io.IOException; +import java.util.List; + +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.opensearch.agent.tools.CreateAlertTool; +import org.opensearch.client.ResponseException; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CreateAlertToolIT extends ToolIntegrationTest { + private final String requestBodyResourceFile = "org/opensearch/agent/tools/register_flow_agent_of_create_alert_tool_request_body.json"; + private final String NORMAL_INDEX = "normal_index"; + private final String NON_EXISTENT_INDEX = "non-existent"; + private final String SYSTEM_INDEX = ".kibana"; + + private final String alertJson = + "{\"name\":\"Error 500 Response Alert\",\"search\":{\"indices\":[\"opensearch_dashboards_sample_data_logs\"],\"timeField\":\"timestamp\",\"bucketValue\":60.0,\"bucketUnitOfTime\":\"m\",\"filters\":[{\"fieldName\":[{\"label\":\"response\",\"type\":\"text\"}],\"fieldValue\":\"500\",\"operator\":\"is\"}],\"aggregations\":[{\"aggregationType\":\"count\",\"fieldName\":\"bytes\"}]},\"triggers\":[{\"name\":\"Error 500 Response Count Above 1\",\"severity\":1.0,\"thresholdValue\":1.0,\"thresholdEnum\":\"ABOVE\"}]}"; + private final String question = "Create alert on the index when count of peoples whose age greater than 50 exceeds 10"; + private final String pureJsonResponseIndicator = "$PURE_JSON"; + private final String noJsonResponseIndicator = "$NO_JSON"; + + private String agentId; + + @Before + public void registerAgent() throws IOException, InterruptedException { + agentId = registerAgent(modelId, requestBodyResourceFile); + } + + @Override + List promptHandlers() { + PromptHandler CreateAlertHandler = new PromptHandler() { + @Override + String response(String prompt) { + if (prompt.contains(pureJsonResponseIndicator)) { + return alertJson; + } else if (prompt.contains(noJsonResponseIndicator)) { + return "No json response"; + } + return "This is output: ```json" + alertJson + "```"; + } + + @Override + boolean apply(String prompt) { + return true; + } + }; + return List.of(CreateAlertHandler); + } + + @Override + String toolType() { + return CreateAlertTool.TYPE; + } + + @SneakyThrows + public void testCreateAlertTool() { + prepareIndex(); + String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NORMAL_INDEX); + String result = executeAgent(agentId, requestBody); + assertEquals(alertJson, result); + } + + public void testCreateAlertToolWithPureJsonResponse() { + prepareIndex(); + String requestBody = String + .format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question + pureJsonResponseIndicator, NORMAL_INDEX); + String result = executeAgent(agentId, requestBody); + assertEquals(alertJson, result); + } + + public void testCreateAlertToolWithNoJsonResponse() { + prepareIndex(); + String requestBody = String + .format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question + noJsonResponseIndicator, NORMAL_INDEX); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody)); + MatcherAssert.assertThat(exception.getMessage(), containsString("The response from LLM is not a json")); + } + + public void testCreateAlertToolWithNonExistentModelId() { + prepareIndex(); + String abnormalAgentId = registerAgent("NON_EXISTENT_MODEL_ID", requestBodyResourceFile); + String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NORMAL_INDEX); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(abnormalAgentId, requestBody)); + MatcherAssert.assertThat(exception.getMessage(), containsString("Failed to find model")); + } + + public void testCreateAlertToolWithNonExistentIndex() { + prepareIndex(); + String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, NON_EXISTENT_INDEX); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody)); + MatcherAssert + .assertThat( + exception.getMessage(), + containsString( + "Cannot find provided indices [non-existent]. Ask user to check the provided indices as your final answer without using any other tools" + ) + ); + } + + public void testCreateAlertToolWithSystemIndex() { + prepareIndex(); + String agentId = registerAgent(modelId, requestBodyResourceFile); + String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"[%s]\"}}", question, SYSTEM_INDEX); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody)); + MatcherAssert.assertThat(exception.getMessage(), containsString("contains system index, which is not allowed")); + } + + public void testCreateAlertToolWithEmptyIndex() { + prepareIndex(); + String agentId = registerAgent(modelId, requestBodyResourceFile); + String requestBody = String.format("{\"parameters\": {\"question\": \"%s\", \"indices\": \"\"}}", question); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, requestBody)); + MatcherAssert.assertThat(exception.getMessage(), containsString("No indices in the input parameter")); + } + + @SneakyThrows + private void prepareIndex() { + createIndexWithConfiguration( + NORMAL_INDEX, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"response\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"bytes\": {\n" + + " \"type\": \"long\"\n" + + " },\n" + + " \"timestamp\": {\n" + + " \"type\": \"date\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(NORMAL_INDEX, "0", List.of("response", "bytes", "timestamp"), List.of(200, 1, "2024-07-03T10:22:56,520")); + addDocToIndex(NORMAL_INDEX, "1", List.of("response", "bytes", "timestamp"), List.of(200, 2, "2024-07-03T10:22:57,520")); + } + +} diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_alert_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_alert_tool_request_body.json new file mode 100644 index 00000000..e326e80c --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_alert_tool_request_body.json @@ -0,0 +1,12 @@ +{ + "name": "Test_create_alert_flow_agent", + "type": "flow", + "tools": [ + { + "type": "CreateAlertTool", + "parameters": { + "model_id": "" + } + } + ] +}