Skip to content

Commit

Permalink
Add cluster setting to control ppl execution (#344)
Browse files Browse the repository at this point in the history
* Add cluster setting to control ppl execution

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Add debug log to indicate the ppl execution settings

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo authored Jul 10, 2024
1 parent 00e9466 commit 14d9ef2
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 62 deletions.
14 changes: 12 additions & 2 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.List;
import java.util.function.Supplier;

import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.RAGTool;
Expand All @@ -18,9 +19,12 @@
import org.opensearch.agent.tools.SearchAnomalyResultsTool;
import org.opensearch.agent.tools.SearchMonitorsTool;
import org.opensearch.agent.tools.VectorDBTool;
import org.opensearch.agent.tools.utils.ClusterSettingHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
Expand Down Expand Up @@ -59,8 +63,9 @@ public Collection<Object> createComponents(
this.client = client;
this.clusterService = clusterService;
this.xContentRegistry = xContentRegistry;

PPLTool.Factory.getInstance().init(client);
Settings settings = environment.settings();
ClusterSettingHelper clusterSettingHelper = new ClusterSettingHelper(settings, clusterService);
PPLTool.Factory.getInstance().init(client, clusterSettingHelper);
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
RAGTool.Factory.getInstance().init(client, xContentRegistry);
Expand All @@ -85,4 +90,9 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
SearchMonitorsTool.Factory.getInstance()
);
}

@Override
public List<Setting<?>> getSettings() {
return List.of(SkillSettings.PPL_EXECUTION_ENABLED);
}
}
22 changes: 22 additions & 0 deletions src/main/java/org/opensearch/agent/common/SkillSettings.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.common;

import org.opensearch.common.settings.Setting;

/**
* Settings for skills plugin
*/
public final class SkillSettings {

private SkillSettings() {}

/**
* This setting controls whether PPL execution is enabled or not
*/
public static final Setting<Boolean> PPL_EXECUTION_ENABLED = Setting
.boolSetting("plugins.skills.ppl_execution_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
105 changes: 55 additions & 50 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import org.json.JSONObject;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.utils.ClusterSettingHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
Expand All @@ -46,7 +46,6 @@
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -98,6 +97,8 @@ public class PPLTool implements Tool {

private int head;

private ClusterSettingHelper clusterSettingHelper;

private static Gson gson = new Gson();

private static Map<String, String> DEFAULT_PROMPT_DICT;
Expand Down Expand Up @@ -127,12 +128,7 @@ public class PPLTool implements Tool {
ALLOWED_FIELDS_TYPE.add("nested");
ALLOWED_FIELDS_TYPE.add("geo_point");

try {
DEFAULT_PROMPT_DICT = loadDefaultPromptDict();
} catch (IOException e) {
log.error("fail to load default prompt dict" + e.getMessage());
DEFAULT_PROMPT_DICT = new HashMap<>();
}
DEFAULT_PROMPT_DICT = loadDefaultPromptDict();
}

public enum PPLModelType {
Expand All @@ -156,6 +152,7 @@ public static PPLModelType from(String value) {

public PPLTool(
Client client,
ClusterSettingHelper clusterSettingHelper,
String modelId,
String contextPrompt,
String pplModelType,
Expand All @@ -167,18 +164,20 @@ public PPLTool(
this.modelId = modelId;
this.pplModelType = PPLModelType.from(pplModelType);
if (contextPrompt.isEmpty()) {
this.contextPrompt = this.DEFAULT_PROMPT_DICT.getOrDefault(this.pplModelType.toString(), "");
this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.pplModelType.toString(), "");
} else {
this.contextPrompt = contextPrompt;
}
this.previousToolKey = previousToolKey;
this.head = head;
this.execute = execute;
this.clusterSettingHelper = clusterSettingHelper;
}

@SuppressWarnings("unchecked")
@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
parameters = extractFromChatParameters(parameters);
extractFromChatParameters(parameters);
String indexName = getIndexNameFromParameters(parameters);
if (StringUtils.isBlank(indexName)) {
throw new IllegalArgumentException(
Expand All @@ -197,14 +196,14 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
}

GetMappingsRequest getMappingsRequest = buildGetMappingRequest(indexName);
client.admin().indices().getMappings(getMappingsRequest, ActionListener.<GetMappingsResponse>wrap(getMappingsResponse -> {
client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(getMappingsResponse -> {
Map<String, MappingMetadata> mappings = getMappingsResponse.getMappings();
if (mappings.size() == 0) {
if (mappings.isEmpty()) {
throw new IllegalArgumentException("No matching mapping with index name: " + indexName);
}
String firstIndexName = (String) mappings.keySet().toArray()[0];
SearchRequest searchRequest = buildSearchRequest(firstIndexName);
client.search(searchRequest, ActionListener.<SearchResponse>wrap(searchResponse -> {
client.search(searchRequest, ActionListener.wrap(searchResponse -> {
SearchHit[] searchHits = searchResponse.getHits().getHits();
String tableInfo = constructTableInfo(searchHits, mappings);
String prompt = constructPrompt(tableInfo, question.strip(), indexName);
Expand All @@ -217,7 +216,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(),
null
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.<MLTaskResponse>wrap(mlTaskResponse -> {
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput();
ModelTensors modelTensors = modelTensorOutput.getMlModelOutputs().get(0);
ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0);
Expand All @@ -227,7 +226,14 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
return;
}
String ppl = parseOutput(dataAsMap.get("response"), indexName);
if (!this.execute) {
boolean pplExecutedEnabled = clusterSettingHelper.getClusterSettings(SkillSettings.PPL_EXECUTION_ENABLED);
if (!pplExecutedEnabled || !this.execute) {
if (!pplExecutedEnabled) {
log
.debug(
"PPL execution is disabled, the query will be returned directly, to enable this, please set plugins.skills.ppl_execution_enabled to true"
);
}
Map<String, String> ret = ImmutableMap.of("ppl", ppl);
listener.onResponse((T) AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(ret)));
return;
Expand All @@ -239,7 +245,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
.execute(
PPLQueryAction.INSTANCE,
transportPPLQueryRequest,
getPPLTransportActionListener(ActionListener.<TransportPPLQueryResponse>wrap(transportPPLQueryResponse -> {
getPPLTransportActionListener(ActionListener.wrap(transportPPLQueryResponse -> {
String results = transportPPLQueryResponse.getResult();
Map<String, String> returnResults = ImmutableMap.of("ppl", ppl, "executionResult", results);
listener
Expand All @@ -255,17 +261,15 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
);
// Execute output here
}, e -> {
log.info("fail to predict model: " + e);
log.error(String.format(Locale.ROOT, "fail to predict model: %s with error: %s", modelId, e.getMessage()), e);
listener.onFailure(e);
}));
}, e -> {
log.info("fail to search: " + e);
log.error(String.format(Locale.ROOT, "fail to search model: %s with error: %s", modelId, e.getMessage()), e);
listener.onFailure(e);
}

));
}));
}, e -> {
log.info("fail to get mapping: " + e);
log.error(String.format(Locale.ROOT, "fail to get mapping of index: %s with error: %s", indexName, e.getMessage()), e);
String errorMessage = e.getMessage();
if (errorMessage.contains("no such index")) {
listener
Expand All @@ -292,15 +296,14 @@ public String getName() {

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
return false;
}
return true;
return parameters != null && !parameters.isEmpty();
}

public static class Factory implements Tool.Factory<PPLTool> {
private Client client;

private ClusterSettingHelper clusterSettingHelper;

private static Factory INSTANCE;

public static Factory getInstance() {
Expand All @@ -316,21 +319,23 @@ public static Factory getInstance() {
}
}

public void init(Client client) {
public void init(Client client, ClusterSettingHelper clusterSettingHelper) {
this.client = client;
this.clusterSettingHelper = clusterSettingHelper;
}

@Override
public PPLTool create(Map<String, Object> map) {
validatePPLToolParameters(map);
return new PPLTool(
client,
clusterSettingHelper,
(String) map.get("model_id"),
(String) map.getOrDefault("prompt", ""),
(String) map.getOrDefault("model_type", ""),
(String) map.getOrDefault("previous_tool_name", ""),
Integer.valueOf((String) map.getOrDefault("head", "-1")),
Boolean.valueOf((String) map.getOrDefault("execute", "true"))
NumberUtils.toInt((String) map.get("head"), -1),
Boolean.parseBoolean((String) map.getOrDefault("execute", "true"))
);
}

Expand All @@ -355,8 +360,7 @@ private SearchRequest buildSearchRequest(String indexName) {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.size(1).query(new MatchAllQueryBuilder());
// client;
SearchRequest request = new SearchRequest(new String[] { indexName }, searchSourceBuilder);
return request;
return new SearchRequest(new String[] { indexName }, searchSourceBuilder);
}

private GetMappingsRequest buildGetMappingRequest(String indexName) {
Expand Down Expand Up @@ -426,19 +430,17 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
}
}

String tableInfo = tableInfoJoiner.toString();
return tableInfo;
return tableInfoJoiner.toString();
}

private String constructPrompt(String tableInfo, String question, String indexName) {
Map<String, String> indexInfo = ImmutableMap.of("mappingInfo", tableInfo, "question", question, "indexName", indexName);
StringSubstitutor substitutor = new StringSubstitutor(indexInfo, "${indexInfo.", "}");
String finalPrompt = substitutor.replace(contextPrompt);
return finalPrompt;
return substitutor.replace(contextPrompt);
}

private void extractNamesTypes(Map<String, Object> mappingSource, Map<String, String> fieldsToType, String prefix) {
if (prefix.length() > 0) {
if (!prefix.isEmpty()) {
prefix += ".";
}

Expand All @@ -461,7 +463,7 @@ private void extractNamesTypes(Map<String, Object> mappingSource, Map<String, St

private static void extractSamples(Map<String, Object> sampleSource, Map<String, String> fieldsToSample, String prefix)
throws PrivilegedActionException {
if (prefix.length() > 0) {
if (!prefix.isEmpty()) {
prefix += ".";
}

Expand All @@ -484,16 +486,17 @@ private <T extends ActionResponse> ActionListener<T> getPPLTransportActionListen
return ActionListener.wrap(r -> { listener.onResponse(TransportPPLQueryResponse.fromActionResponse(r)); }, listener::onFailure);
}

private Map<String, String> extractFromChatParameters(Map<String, String> parameters) {
@SuppressWarnings("unchecked")
private void extractFromChatParameters(Map<String, String> parameters) {
if (parameters.containsKey("input")) {
String input = parameters.get("input");
try {
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
Map<String, String> chatParameters = gson.fromJson(input, Map.class);
parameters.putAll(chatParameters);
} finally {
return parameters;
} catch (Exception e) {
log.error(String.format(Locale.ROOT, "Failed to parse chat parameters, input is: %s, which is not a valid json", input), e);
}
}
return parameters;
}

private String parseOutput(String llmOutput, String indexName) {
Expand Down Expand Up @@ -557,14 +560,16 @@ private String getIndexNameFromParameters(Map<String, String> parameters) {
return indexName.trim();
}

private static Map<String, String> loadDefaultPromptDict() throws IOException {
InputStream searchResponseIns = PPLTool.class.getResourceAsStream("PPLDefaultPrompt.json");
if (searchResponseIns != null) {
String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
Map<String, String> defaultPromptDict = gson.fromJson(defaultPromptContent, Map.class);
return defaultPromptDict;
@SuppressWarnings("unchecked")
private static Map<String, String> loadDefaultPromptDict() {
try (InputStream searchResponseIns = PPLTool.class.getResourceAsStream("PPLDefaultPrompt.json")) {
if (searchResponseIns != null) {
String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
return gson.fromJson(defaultPromptContent, Map.class);
}
} catch (IOException e) {
log.error("Failed to load default prompt dict", e);
}
return new HashMap<>();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools.utils;

import java.util.Optional;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;

import lombok.AllArgsConstructor;

/**
* This class is to encapsulate the {@link Settings} and {@link ClusterService} and provide a general method to retrieve dynamical cluster settings conveniently.
*/
@AllArgsConstructor
public class ClusterSettingHelper {

private Settings settings;

private ClusterService clusterService;

/**
* Retrieves the cluster settings for the specified setting.
*
* @param setting the setting to retrieve cluster settings for
* @return the cluster setting value, or the default setting value if not found
*/
public <T> T getClusterSettings(Setting<T> setting) {
return Optional.ofNullable(clusterService.getClusterSettings().get(setting)).orElse(setting.get(settings));
}
}
Loading

0 comments on commit 14d9ef2

Please sign in to comment.