Skip to content

Commit

Permalink
pick heng's code
Browse files Browse the repository at this point in the history
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws committed Aug 5, 2024
1 parent 769df71 commit 9a869e6
Showing 1 changed file with 43 additions and 89 deletions.
132 changes: 43 additions & 89 deletions src/main/java/org/opensearch/agent/tools/RCATool.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@

import static org.apache.commons.text.StringEscapeUtils.unescapeJson;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainRequest;
import org.opensearch.action.admin.cluster.allocation.ClusterAllocationExplainResponse;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.agent.tools.utils.ClusterStatsUtil;
import org.opensearch.client.Client;
import org.opensearch.cluster.routing.allocation.NodeAllocationResult;
Expand Down Expand Up @@ -98,52 +97,44 @@ public <T> void runOption1(Map<String, String> parameters, ActionListener<T> lis
knowledge = unescapeJson(knowledge);
Map<String, ?> knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class);
List<Map<String, String>> causes = (List<Map<String, String>>) knowledgeBase.get("causes");
Set<String> apis = causes.stream().map(c -> c.get(API_URL_FIELD)).collect(Collectors.toSet());
ActionListener<Map<String, String>> apiListener = new ActionListener<>() {
@Override
public void onResponse(Map<String, String> apiToResponse) {
Map<String, String> LLMParams = new java.util.HashMap<>(
Map
.of(
"phenomenon",
(String) knowledgeBase.get("phenomenon"),
"causes",
StringUtils.gson.toJson(causes),
"responses",
StringUtils.gson.toJson(apiToResponse)
)
);
StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}");
String finalToolPrompt = substitute.replace(TOOL_PROMPT);
LLMParams.put("prompt", finalToolPrompt);
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(LLMParams).build();
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response.getOutput();
Map<String, ?> dataMap = Optional
.ofNullable(modelTensorOutput.getMlModelOutputs())
.flatMap(outputs -> outputs.stream().findFirst())
.flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst())
.map(ModelTensor::getDataAsMap)
.orElse(null);
if (dataMap == null) {
throw new IllegalArgumentException("No dataMap returned from LLM.");
}
listener.onResponse((T) dataMap.get("completion"));
}, listener::onFailure));
}

@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
};
List<String> apiList = causes.stream().map(cause -> cause.get(API_URL_FIELD)).distinct().collect(Collectors.toList());
final GroupedActionListener<Pair<String, String>> groupedListener = new GroupedActionListener<>(ActionListener.wrap(responses -> {
Map<String, String> apiToResponse = responses.stream().collect(Collectors.toMap(Pair::getKey, Pair::getValue));
Map<String, String> LLMParams = new java.util.HashMap<>(
Map
.of(
"phenomenon",
(String) knowledgeBase.get("phenomenon"),
"causes",
StringUtils.gson.toJson(causes),
"responses",
StringUtils.gson.toJson(apiToResponse)
)
);
StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}");
String finalToolPrompt = substitute.replace(TOOL_PROMPT);
LLMParams.put("prompt", finalToolPrompt);
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(LLMParams).build();
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response.getOutput();
Map<String, ?> dataMap = Optional
.ofNullable(modelTensorOutput.getMlModelOutputs())
.flatMap(outputs -> outputs.stream().findFirst())
.flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst())
.map(ModelTensor::getDataAsMap)
.orElse(null);
if (dataMap == null) {
throw new IllegalArgumentException("No dataMap returned from LLM.");
}
listener.onResponse((T) dataMap.get("completion"));
}, listener::onFailure));
}, listener::onFailure), apiList.size());
// TODO: support different parameters for different apis
Map<String, Map<String, String>> apiToParameters = apis.stream().collect(Collectors.toMap(api -> api, api -> parameters));
invokeAPIs(apis, apiToParameters, apiListener);
apiList.forEach(api -> invokeAPI(api, parameters, groupedListener));
}

/**
Expand All @@ -164,44 +155,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
}
}

private void invokeAPIs(Set<String> urls, Map<String, Map<String, String>> parameters, ActionListener<Map<String, String>> listener) {
Map<String, CompletableFuture<String>> apiFutures = new HashMap<>();
for (String url : urls) {
Map<String, String> parameter = parameters.get(url);
CompletableFuture<String> apiFuture = new CompletableFuture<>();
apiFutures.put(url, apiFuture);

ActionListener<String> apiListener = new ActionListener<>() {
@Override
public void onResponse(String response) {
apiFuture.complete(response);
}

@Override
public void onFailure(Exception e) {
apiFuture.completeExceptionally(e);
listener.onFailure(e);
}
};

invokeAPI(url, parameter, apiListener);
}

try {
CompletableFuture<Map<String, String>> mapFuture = CompletableFuture
.allOf(apiFutures.values().toArray(new CompletableFuture<?>[0]))
.thenApply(
v -> apiFutures.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().join()))
);
Map<String, String> apiToResponse = mapFuture.join();
listener.onResponse(apiToResponse);
} catch (Exception e) {
log.error("Failed to get all api results from rca tool", e);
listener.onFailure(e);
}
}

private void invokeAPI(String url, Map<String, String> parameters, ActionListener<String> listener) {
private void invokeAPI(String url, Map<String, String> parameters, GroupedActionListener<Pair<String, String>> groupedListener) {
// TODO: add other API urls
switch (url) {
case "_cluster/allocation/explain":
Expand All @@ -220,12 +174,12 @@ public void onResponse(ClusterAllocationExplainResponse allocationExplainRespons
stringBuilder.append(decision.getExplanation());
}
}
listener.onResponse(stringBuilder.toString());
groupedListener.onResponse(Pair.of("_cluster/allocation/explain", stringBuilder.toString()));
}

@Override
public void onFailure(Exception e) {
listener.onFailure(e);
groupedListener.onFailure(e);
}
};

Expand All @@ -239,7 +193,7 @@ public void onFailure(Exception e) {
break;
default:
Exception exception = new IllegalArgumentException("API not supported");
listener.onFailure(exception);
groupedListener.onFailure(exception);
}
}

Expand Down

0 comments on commit 9a869e6

Please sign in to comment.