diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index db07ac0b..d5c418ae 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.function.Supplier; +import org.opensearch.agent.common.SkillSettings; import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; import org.opensearch.agent.tools.RAGTool; @@ -18,9 +19,12 @@ import org.opensearch.agent.tools.SearchAnomalyResultsTool; import org.opensearch.agent.tools.SearchMonitorsTool; import org.opensearch.agent.tools.VectorDBTool; +import org.opensearch.agent.tools.utils.ClusterSettingHelper; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; @@ -59,8 +63,9 @@ public Collection createComponents( this.client = client; this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; - - PPLTool.Factory.getInstance().init(client); + Settings settings = environment.settings(); + ClusterSettingHelper clusterSettingHelper = new ClusterSettingHelper(settings, clusterService); + PPLTool.Factory.getInstance().init(client, clusterSettingHelper); NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry); VectorDBTool.Factory.getInstance().init(client, xContentRegistry); RAGTool.Factory.getInstance().init(client, xContentRegistry); @@ -85,4 +90,9 @@ public List> getToolFactories() { SearchMonitorsTool.Factory.getInstance() ); } + + @Override + public List> getSettings() { + return List.of(SkillSettings.PPL_EXECUTION_ENABLED); + } } diff --git a/src/main/java/org/opensearch/agent/common/SkillSettings.java b/src/main/java/org/opensearch/agent/common/SkillSettings.java new file mode 100644 index 00000000..55808748 --- /dev/null +++ b/src/main/java/org/opensearch/agent/common/SkillSettings.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.common; + +import org.opensearch.common.settings.Setting; + +/** + * Settings for skills plugin + */ +public final class SkillSettings { + + private SkillSettings() {} + + /** + * This setting controls whether PPL execution is enabled or not + */ + public static final Setting PPL_EXECUTION_ENABLED = Setting + .boolSetting("plugins.skills.ppl_execution_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); +} diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index 25201fe1..3f8c728b 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -30,9 +30,9 @@ import org.json.JSONObject; import org.opensearch.action.ActionRequest; import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; -import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; +import org.opensearch.agent.common.SkillSettings; +import org.opensearch.agent.tools.utils.ClusterSettingHelper; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; @@ -46,7 +46,6 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; @@ -98,6 +97,8 @@ public class PPLTool implements Tool { private int head; + private ClusterSettingHelper clusterSettingHelper; + private static Gson gson = new Gson(); private static Map DEFAULT_PROMPT_DICT; @@ -127,12 +128,7 @@ public class PPLTool implements Tool { ALLOWED_FIELDS_TYPE.add("nested"); ALLOWED_FIELDS_TYPE.add("geo_point"); - try { - DEFAULT_PROMPT_DICT = loadDefaultPromptDict(); - } catch (IOException e) { - log.error("fail to load default prompt dict" + e.getMessage()); - DEFAULT_PROMPT_DICT = new HashMap<>(); - } + DEFAULT_PROMPT_DICT = loadDefaultPromptDict(); } public enum PPLModelType { @@ -156,6 +152,7 @@ public static PPLModelType from(String value) { public PPLTool( Client client, + ClusterSettingHelper clusterSettingHelper, String modelId, String contextPrompt, String pplModelType, @@ -167,18 +164,20 @@ public PPLTool( this.modelId = modelId; this.pplModelType = PPLModelType.from(pplModelType); if (contextPrompt.isEmpty()) { - this.contextPrompt = this.DEFAULT_PROMPT_DICT.getOrDefault(this.pplModelType.toString(), ""); + this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.pplModelType.toString(), ""); } else { this.contextPrompt = contextPrompt; } this.previousToolKey = previousToolKey; this.head = head; this.execute = execute; + this.clusterSettingHelper = clusterSettingHelper; } + @SuppressWarnings("unchecked") @Override public void run(Map parameters, ActionListener listener) { - parameters = extractFromChatParameters(parameters); + extractFromChatParameters(parameters); String indexName = getIndexNameFromParameters(parameters); if (StringUtils.isBlank(indexName)) { throw new IllegalArgumentException( @@ -197,14 +196,14 @@ public void run(Map parameters, ActionListener listener) } GetMappingsRequest getMappingsRequest = buildGetMappingRequest(indexName); - client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(getMappingsResponse -> { + client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(getMappingsResponse -> { Map mappings = getMappingsResponse.getMappings(); - if (mappings.size() == 0) { + if (mappings.isEmpty()) { throw new IllegalArgumentException("No matching mapping with index name: " + indexName); } String firstIndexName = (String) mappings.keySet().toArray()[0]; SearchRequest searchRequest = buildSearchRequest(firstIndexName); - client.search(searchRequest, ActionListener.wrap(searchResponse -> { + client.search(searchRequest, ActionListener.wrap(searchResponse -> { SearchHit[] searchHits = searchResponse.getHits().getHits(); String tableInfo = constructTableInfo(searchHits, mappings); String prompt = constructPrompt(tableInfo, question.strip(), indexName); @@ -216,13 +215,20 @@ public void run(Map parameters, ActionListener listener) modelId, MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build() ); - client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> { + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> { ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput(); ModelTensors modelTensors = modelTensorOutput.getMlModelOutputs().get(0); ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0); Map dataAsMap = (Map) modelTensor.getDataAsMap(); String ppl = parseOutput(dataAsMap.get("response"), indexName); - if (!this.execute) { + boolean pplExecutedEnabled = clusterSettingHelper.getClusterSettings(SkillSettings.PPL_EXECUTION_ENABLED); + if (!pplExecutedEnabled || !this.execute) { + if (!pplExecutedEnabled) { + log + .debug( + "PPL execution is disabled, the query will be returned directly, to enable this, please set plugins.skills.ppl_execution_enabled to true" + ); + } Map ret = ImmutableMap.of("ppl", ppl); listener.onResponse((T) AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(ret))); return; @@ -234,7 +240,7 @@ public void run(Map parameters, ActionListener listener) .execute( PPLQueryAction.INSTANCE, transportPPLQueryRequest, - getPPLTransportActionListener(ActionListener.wrap(transportPPLQueryResponse -> { + getPPLTransportActionListener(ActionListener.wrap(transportPPLQueryResponse -> { String results = transportPPLQueryResponse.getResult(); Map returnResults = ImmutableMap.of("ppl", ppl, "executionResult", results); listener @@ -250,17 +256,15 @@ public void run(Map parameters, ActionListener listener) ); // Execute output here }, e -> { - log.info("fail to predict model: " + e); + log.error(String.format(Locale.ROOT, "fail to predict model: %s with error: %s", modelId, e.getMessage()), e); listener.onFailure(e); })); }, e -> { - log.info("fail to search: " + e); + log.error(String.format(Locale.ROOT, "fail to search model: %s with error: %s", modelId, e.getMessage()), e); listener.onFailure(e); - } - - )); + })); }, e -> { - log.info("fail to get mapping: " + e); + log.error(String.format(Locale.ROOT, "fail to get mapping of index: %s with error: %s", indexName, e.getMessage()), e); String errorMessage = e.getMessage(); if (errorMessage.contains("no such index")) { listener @@ -287,15 +291,14 @@ public String getName() { @Override public boolean validate(Map parameters) { - if (parameters == null || parameters.size() == 0) { - return false; - } - return true; + return parameters != null && !parameters.isEmpty(); } public static class Factory implements Tool.Factory { private Client client; + private ClusterSettingHelper clusterSettingHelper; + private static Factory INSTANCE; public static Factory getInstance() { @@ -311,8 +314,9 @@ public static Factory getInstance() { } } - public void init(Client client) { + public void init(Client client, ClusterSettingHelper clusterSettingHelper) { this.client = client; + this.clusterSettingHelper = clusterSettingHelper; } @Override @@ -320,12 +324,13 @@ public PPLTool create(Map map) { validatePPLToolParameters(map); return new PPLTool( client, + clusterSettingHelper, (String) map.get("model_id"), (String) map.getOrDefault("prompt", ""), (String) map.getOrDefault("model_type", ""), (String) map.getOrDefault("previous_tool_name", ""), - Integer.valueOf((String) map.getOrDefault("head", "-1")), - Boolean.valueOf((String) map.getOrDefault("execute", "true")) + NumberUtils.toInt((String) map.get("head"), -1), + Boolean.parseBoolean((String) map.getOrDefault("execute", "true")) ); } @@ -350,8 +355,7 @@ private SearchRequest buildSearchRequest(String indexName) { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.size(1).query(new MatchAllQueryBuilder()); // client; - SearchRequest request = new SearchRequest(new String[] { indexName }, searchSourceBuilder); - return request; + return new SearchRequest(new String[] { indexName }, searchSourceBuilder); } private GetMappingsRequest buildGetMappingRequest(String indexName) { @@ -421,19 +425,17 @@ private String constructTableInfo(SearchHit[] searchHits, Map indexInfo = ImmutableMap.of("mappingInfo", tableInfo, "question", question, "indexName", indexName); StringSubstitutor substitutor = new StringSubstitutor(indexInfo, "${indexInfo.", "}"); - String finalPrompt = substitutor.replace(contextPrompt); - return finalPrompt; + return substitutor.replace(contextPrompt); } private void extractNamesTypes(Map mappingSource, Map fieldsToType, String prefix) { - if (prefix.length() > 0) { + if (!prefix.isEmpty()) { prefix += "."; } @@ -456,7 +458,7 @@ private void extractNamesTypes(Map mappingSource, Map sampleSource, Map fieldsToSample, String prefix) throws PrivilegedActionException { - if (prefix.length() > 0) { + if (!prefix.isEmpty()) { prefix += "."; } @@ -479,16 +481,17 @@ private ActionListener getPPLTransportActionListen return ActionListener.wrap(r -> { listener.onResponse(TransportPPLQueryResponse.fromActionResponse(r)); }, listener::onFailure); } - private Map extractFromChatParameters(Map parameters) { + @SuppressWarnings("unchecked") + private void extractFromChatParameters(Map parameters) { if (parameters.containsKey("input")) { + String input = parameters.get("input"); try { - Map chatParameters = gson.fromJson(parameters.get("input"), Map.class); + Map chatParameters = gson.fromJson(input, Map.class); parameters.putAll(chatParameters); - } finally { - return parameters; + } catch (Exception e) { + log.error(String.format(Locale.ROOT, "Failed to parse chat parameters, input is: %s, which is not a valid json", input), e); } } - return parameters; } private String parseOutput(String llmOutput, String indexName) { @@ -552,14 +555,16 @@ private String getIndexNameFromParameters(Map parameters) { return indexName.trim(); } - private static Map loadDefaultPromptDict() throws IOException { - InputStream searchResponseIns = PPLTool.class.getResourceAsStream("PPLDefaultPrompt.json"); - if (searchResponseIns != null) { - String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); - Map defaultPromptDict = gson.fromJson(defaultPromptContent, Map.class); - return defaultPromptDict; + @SuppressWarnings("unchecked") + private static Map loadDefaultPromptDict() { + try (InputStream searchResponseIns = PPLTool.class.getResourceAsStream("PPLDefaultPrompt.json")) { + if (searchResponseIns != null) { + String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + return gson.fromJson(defaultPromptContent, Map.class); + } + } catch (IOException e) { + log.error("Failed to load default prompt dict", e); } return new HashMap<>(); } - } diff --git a/src/main/java/org/opensearch/agent/tools/utils/ClusterSettingHelper.java b/src/main/java/org/opensearch/agent/tools/utils/ClusterSettingHelper.java new file mode 100644 index 00000000..92bf9dcd --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/ClusterSettingHelper.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils; + +import java.util.Optional; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; + +import lombok.AllArgsConstructor; + +/** + * This class is to encapsulate the {@link Settings} and {@link ClusterService} and provide a general method to retrieve dynamical cluster settings conveniently. + */ +@AllArgsConstructor +public class ClusterSettingHelper { + + private Settings settings; + + private ClusterService clusterService; + + /** + * Retrieves the cluster settings for the specified setting. + * + * @param setting the setting to retrieve cluster settings for + * @return the cluster setting value, or the default setting value if not found + */ + public T getClusterSettings(Setting setting) { + return Optional.ofNullable(clusterService.getClusterSettings().get(setting)).orElse(setting.get(settings)); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java index 25fe62a9..8e2c3aaa 100644 --- a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java @@ -9,6 +9,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -16,6 +17,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Set; import org.apache.lucene.search.TotalHits; import org.junit.Before; @@ -24,10 +26,15 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.action.search.SearchResponse; +import org.opensearch.agent.common.SkillSettings; +import org.opensearch.agent.tools.utils.ClusterSettingHelper; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; @@ -122,7 +129,12 @@ public void setup() { return null; }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); - PPLTool.Factory.getInstance().init(client); + Settings settings = Settings.builder().put(SkillSettings.PPL_EXECUTION_ENABLED.getKey(), true).build(); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(SkillSettings.PPL_EXECUTION_ENABLED))); + ClusterSettingHelper clusterSettingHelper = new ClusterSettingHelper(settings, clusterService); + PPLTool.Factory.getInstance().init(client, clusterSettingHelper); } @Test @@ -401,6 +413,26 @@ public void testTool_executePPLFailure() { ); } + @Test + public void test_pplTool_whenPPLExecutionDisabled_returnOnlyContainsPPL() { + Settings settings = Settings.builder().put(SkillSettings.PPL_EXECUTION_ENABLED.getKey(), false).build(); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(SkillSettings.PPL_EXECUTION_ENABLED))); + ClusterSettingHelper clusterSettingHelper = new ClusterSettingHelper(settings, clusterService); + PPLTool.Factory.getInstance().init(client, clusterSettingHelper); + PPLTool tool = PPLTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "head", "100")); + assertEquals(PPLTool.TYPE, tool.getName()); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map returnResults = gson.fromJson(executePPLResult, Map.class); + assertNull(returnResults.get("executionResult")); + assertEquals("source=demo| head 1", returnResults.get("ppl")); + }, log::error)); + } + private void createMappings() { indexMappings = new HashMap<>(); indexMappings diff --git a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java index 658a3fc7..853a2974 100644 --- a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java +++ b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java @@ -63,6 +63,7 @@ public void updateClusterSettings() { updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 100); updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true); updateClusterSettings("plugins.ml_commons.agent_framework_enabled", true); + updateClusterSettings("plugins.skills.ppl_execution_enabled", true); } @SneakyThrows diff --git a/src/test/java/org/opensearch/integTest/PPLToolIT.java b/src/test/java/org/opensearch/integTest/PPLToolIT.java index d25c6a95..b208e1f2 100644 --- a/src/test/java/org/opensearch/integTest/PPLToolIT.java +++ b/src/test/java/org/opensearch/integTest/PPLToolIT.java @@ -22,8 +22,6 @@ @Log4j2 public class PPLToolIT extends ToolIntegrationTest { - private String TEST_INDEX_NAME = "employee"; - @Override List promptHandlers() { PromptHandler PPLHandler = new PromptHandler() { @@ -60,6 +58,14 @@ public void testPPLTool() { ); } + public void test_PPLTool_whenPPLExecutionDisabled_ResultOnlyContainsPPL() { + updateClusterSettings("plugins.skills.ppl_execution_enabled", false); + prepareIndex(); + String agentId = registerAgent(); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \"employee\"}}"); + assertEquals("{\"ppl\":\"source\\u003demployee| where age \\u003e 56 | stats COUNT() as cnt\"}", result); + } + public void testPPLTool_withWrongPPLGenerated_thenThrowException() { prepareIndex(); String agentId = registerAgent(); @@ -148,8 +154,7 @@ private String registerAgent() { ) ); registerAgentRequestBody = registerAgentRequestBody.replace("", modelId); - String agentId = createAgent(registerAgentRequestBody); - return agentId; + return createAgent(registerAgentRequestBody); } @SneakyThrows @@ -166,14 +171,14 @@ private String registerAgentWithWrongModelId() { ) ); registerAgentRequestBody = registerAgentRequestBody.replace("", "wrong_model_id"); - String agentId = createAgent(registerAgentRequestBody); - return agentId; + return createAgent(registerAgentRequestBody); } @SneakyThrows private void prepareIndex() { + String testIndexName = "employee"; createIndexWithConfiguration( - TEST_INDEX_NAME, + testIndexName, "{\n" + " \"mappings\": {\n" + " \"properties\": {\n" @@ -187,8 +192,8 @@ private void prepareIndex() { + " }\n" + "}" ); - addDocToIndex(TEST_INDEX_NAME, "0", List.of("age", "name"), List.of(56, "john")); - addDocToIndex(TEST_INDEX_NAME, "1", List.of("age", "name"), List.of(56, "smith")); + addDocToIndex(testIndexName, "0", List.of("age", "name"), List.of(56, "john")); + addDocToIndex(testIndexName, "1", List.of("age", "name"), List.of(56, "smith")); } }