diff --git a/src/main/java/org/opensearch/agent/tools/RCATool.java b/src/main/java/org/opensearch/agent/tools/RCATool.java index 6bc286fd..955ccd6a 100644 --- a/src/main/java/org/opensearch/agent/tools/RCATool.java +++ b/src/main/java/org/opensearch/agent/tools/RCATool.java @@ -117,22 +117,26 @@ public void runOption1(Map knowledgeBase, ActionListener liste @SuppressWarnings("unchecked") public void runOption2(Map knowledgeBase, ActionListener listener) { - // input phenomenon's embedded dense vector String phenomenon = (String) knowledgeBase.get("phenomenon"); - List inputVector = getEmbeddedVector(Collections.singletonList(phenomenon)); - // api response embedded dense vectors + // API response embedded vectors List> causes = (List>) knowledgeBase.get("causes"); - List apiResponses = causes.stream() + List responses = causes.stream() .map(cause -> cause.get("response")) .collect(Collectors.toList()); - List rootCauseVectors = getEmbeddedVector(apiResponses); + List responseVectors = getEmbeddedVector(responses); - Map dotProductMap = IntStream.range(0, rootCauseVectors.size()) + // expected API response embedded vectors + List expectedResponses = causes.stream() + .map(cause -> cause.get("expected_response")) + .collect(Collectors.toList()); + List expectedResponseVectors = getEmbeddedVector(expectedResponses); + + Map dotProductMap = IntStream.range(0, causes.size()) .boxed() .collect(Collectors.toMap( i -> causes.get(i).get("reason"), - i -> inputVector.get(0).dotProduct(rootCauseVectors.get(i)) + i -> responseVectors.get(i).dotProduct(expectedResponseVectors.get(i)) )); Optional> mapEntry =