diff --git a/src/main/java/org/opensearch/agent/tools/RCATool.java b/src/main/java/org/opensearch/agent/tools/RCATool.java index fc0e3b80..95fb4d52 100644 --- a/src/main/java/org/opensearch/agent/tools/RCATool.java +++ b/src/main/java/org/opensearch/agent/tools/RCATool.java @@ -7,19 +7,18 @@ import static org.apache.commons.text.StringEscapeUtils.unescapeJson; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.text.StringSubstitutor; import org.apache.logging.log4j.util.Strings; import org.opensearch.action.ActionRequest; import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainRequest; import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainResponse; +import org.opensearch.action.support.GroupedActionListener; import org.opensearch.agent.tools.utils.ClusterStatsUtil; import org.opensearch.client.Client; import org.opensearch.cluster.routing.allocation.NodeAllocationResult; @@ -98,52 +97,44 @@ public void runOption1(Map parameters, ActionListener lis knowledge = unescapeJson(knowledge); Map knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class); List> causes = (List>) knowledgeBase.get("causes"); - Set apis = causes.stream().map(c -> c.get(API_URL_FIELD)).collect(Collectors.toSet()); - ActionListener> apiListener = new ActionListener<>() { - @Override - public void onResponse(Map apiToResponse) { - Map LLMParams = new java.util.HashMap<>( - Map - .of( - "phenomenon", - (String) knowledgeBase.get("phenomenon"), - "causes", - StringUtils.gson.toJson(causes), - "responses", - StringUtils.gson.toJson(apiToResponse) - ) - ); - StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}"); - String finalToolPrompt = substitute.replace(TOOL_PROMPT); - LLMParams.put("prompt", finalToolPrompt); - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(LLMParams).build(); - ActionRequest request = new MLPredictionTaskRequest( - modelId, - MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build() - ); - client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { - ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response.getOutput(); - Map dataMap = Optional - .ofNullable(modelTensorOutput.getMlModelOutputs()) - .flatMap(outputs -> outputs.stream().findFirst()) - .flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst()) - .map(ModelTensor::getDataAsMap) - .orElse(null); - if (dataMap == null) { - throw new IllegalArgumentException("No dataMap returned from LLM."); - } - listener.onResponse((T) dataMap.get("completion")); - }, listener::onFailure)); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }; + List apiList = causes.stream().map(cause -> cause.get(API_URL_FIELD)).distinct().collect(Collectors.toList()); + final GroupedActionListener> groupedListener = new GroupedActionListener<>(ActionListener.wrap(responses -> { + Map apiToResponse = responses.stream().collect(Collectors.toMap(Pair::getKey, Pair::getValue)); + Map LLMParams = new java.util.HashMap<>( + Map + .of( + "phenomenon", + (String) knowledgeBase.get("phenomenon"), + "causes", + StringUtils.gson.toJson(causes), + "responses", + StringUtils.gson.toJson(apiToResponse) + ) + ); + StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}"); + String finalToolPrompt = substitute.replace(TOOL_PROMPT); + LLMParams.put("prompt", finalToolPrompt); + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(LLMParams).build(); + ActionRequest request = new MLPredictionTaskRequest( + modelId, + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build() + ); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response.getOutput(); + Map dataMap = Optional + .ofNullable(modelTensorOutput.getMlModelOutputs()) + .flatMap(outputs -> outputs.stream().findFirst()) + .flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst()) + .map(ModelTensor::getDataAsMap) + .orElse(null); + if (dataMap == null) { + throw new IllegalArgumentException("No dataMap returned from LLM."); + } + listener.onResponse((T) dataMap.get("completion")); + }, listener::onFailure)); + }, listener::onFailure), apiList.size()); // TODO: support different parameters for different apis - Map> apiToParameters = apis.stream().collect(Collectors.toMap(api -> api, api -> parameters)); - invokeAPIs(apis, apiToParameters, apiListener); + apiList.forEach(api -> invokeAPI(api, parameters, groupedListener)); } /** @@ -164,44 +155,7 @@ public void run(Map parameters, ActionListener listener) } } - private void invokeAPIs(Set urls, Map> parameters, ActionListener> listener) { - Map> apiFutures = new HashMap<>(); - for (String url : urls) { - Map parameter = parameters.get(url); - CompletableFuture apiFuture = new CompletableFuture<>(); - apiFutures.put(url, apiFuture); - - ActionListener apiListener = new ActionListener<>() { - @Override - public void onResponse(String response) { - apiFuture.complete(response); - } - - @Override - public void onFailure(Exception e) { - apiFuture.completeExceptionally(e); - listener.onFailure(e); - } - }; - - invokeAPI(url, parameter, apiListener); - } - - try { - CompletableFuture> mapFuture = CompletableFuture - .allOf(apiFutures.values().toArray(new CompletableFuture[0])) - .thenApply( - v -> apiFutures.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().join())) - ); - Map apiToResponse = mapFuture.join(); - listener.onResponse(apiToResponse); - } catch (Exception e) { - log.error("Failed to get all api results from rca tool", e); - listener.onFailure(e); - } - } - - private void invokeAPI(String url, Map parameters, ActionListener listener) { + private void invokeAPI(String url, Map parameters, GroupedActionListener> groupedListener) { // TODO: add other API urls switch (url) { case "_cluster/allocation/explain": @@ -220,12 +174,12 @@ public void onResponse(ClusterAllocationExplainResponse allocationExplainRespons stringBuilder.append(decision.getExplanation()); } } - listener.onResponse(stringBuilder.toString()); + groupedListener.onResponse(Pair.of("_cluster/allocation/explain", stringBuilder.toString())); } @Override public void onFailure(Exception e) { - listener.onFailure(e); + groupedListener.onFailure(e); } }; @@ -239,7 +193,7 @@ public void onFailure(Exception e) { break; default: Exception exception = new IllegalArgumentException("API not supported"); - listener.onFailure(exception); + groupedListener.onFailure(exception); } }