From bd3155ca51642c9d1545be33fe11cd6bbb52eee4 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Mon, 22 Apr 2024 20:03:58 -0700 Subject: [PATCH] Migrate RAG pipeline to async processing. (#2345) * Migrate RAG pipeline to async processing. Signed-off-by: Austin Lee * Address reviewer comments. Signed-off-by: Austin Lee --------- Signed-off-by: Austin Lee (cherry picked from commit 4b26ebff52828fc7a0e92501eeca8c649627004a) --- .../GenerativeQAResponseProcessor.java | 147 ++++++--- .../client/ConversationalMemoryClient.java | 47 +++ .../client/MachineLearningInternalClient.java | 2 +- .../generative/llm/DefaultLlmImpl.java | 37 ++- .../questionanswering/generative/llm/Llm.java | 4 +- .../GenerativeQAResponseProcessorTests.java | 286 ++++++++++++------ .../generative/llm/DefaultLlmImplTests.java | 210 ++++++++++--- 7 files changed, 542 insertions(+), 191 deletions(-) diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index 3a2d256695..7b1814c2a5 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -27,20 +27,23 @@ import java.util.Map; import java.util.function.BooleanSupplier; -import org.opensearch.OpenSearchException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.exception.MLException; import org.opensearch.search.SearchHit; import org.opensearch.search.pipeline.AbstractProcessor; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput; import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil; @@ -65,8 +68,6 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30; - // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM. - private final String llmModel; private final List contextFields; @@ -106,9 +107,19 @@ protected GenerativeQAResponseProcessor( } @Override - public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + public SearchResponse processResponse(SearchRequest searchRequest, SearchResponse searchResponse) { + // Synchronous call is no longer supported because this execution can occur on a transport thread. + throw new UnsupportedOperationException(); + } - log.info("Entering processResponse."); + @Override + public void processResponseAsync( + SearchRequest request, + SearchResponse response, + PipelineProcessingContext requestContext, + ActionListener responseListener + ) { + log.debug("Entering processResponse."); if (!this.featureFlagSupplier.getAsBoolean()) { throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG); @@ -116,10 +127,12 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request); - Integer timeout = params.getTimeout(); - if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) { - timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS; + Integer t = params.getTimeout(); + if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) { + t = DEFAULT_PROCESSOR_TIME_IN_SECONDS; } + final int timeout = t; + log.debug("Timeout for this request: {} seconds.", timeout); String llmQuestion = params.getLlmQuestion(); String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel(); @@ -128,14 +141,15 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp } String conversationId = params.getConversationId(); + if (conversationId != null && !Strings.hasText(conversationId)) { + throw new IllegalArgumentException("Empty conversation_id is not allowed."); + } Instant start = Instant.now(); Integer interactionSize = params.getInteractionSize(); if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) { interactionSize = DEFAULT_CHAT_HISTORY_WINDOW; } - List chatHistory = (conversationId == null) - ? Collections.emptyList() - : memoryClient.getInteractions(conversationId, interactionSize); + log.debug("Using interaction size of {}", interactionSize); Integer topN = params.getContextSize(); if (topN == null) { @@ -153,10 +167,32 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp effectiveUserInstructions = params.getUserInstructions(); } - start = Instant.now(); - try { - ChatCompletionOutput output = llm - .doChatCompletion( + final List chatHistory = new ArrayList<>(); + if (conversationId == null) { + doChatCompletion( + LlmIOUtil + .createChatCompletionInput( + systemPrompt, + userInstructions, + llmModel, + llmQuestion, + chatHistory, + searchResults, + timeout, + params.getLlmResponseField() + ), + null, + llmQuestion, + searchResults, + response, + responseListener + ); + } else { + final Instant memoryStart = Instant.now(); + memoryClient.getInteractions(conversationId, interactionSize, ActionListener.wrap(r -> { + log.debug("getInteractions complete. ({})", getDuration(memoryStart)); + chatHistory.addAll(r); + doChatCompletion( LlmIOUtil .createChatCompletionInput( systemPrompt, @@ -167,43 +203,70 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp searchResults, timeout, params.getLlmResponseField() - ) + ), + conversationId, + llmQuestion, + searchResults, + response, + responseListener ); - log.info("doChatCompletion complete. ({})", getDuration(start)); + }, responseListener::onFailure)); + } + } - String answer = null; - String errorMessage = null; - String interactionId = null; - if (output.isErrorOccurred()) { - errorMessage = output.getErrors().get(0); - } else { - answer = (String) output.getAnswers().get(0); + private void doChatCompletion( + ChatCompletionInput input, + String conversationId, + String llmQuestion, + List searchResults, + SearchResponse response, + ActionListener responseListener + ) { + + final Instant chatStart = Instant.now(); + llm.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + log.debug("doChatCompletion complete. ({})", getDuration(chatStart)); + + final String answer = getAnswer(output); + final String errorMessage = getError(output); if (conversationId != null) { - start = Instant.now(); - interactionId = memoryClient + final Instant memoryStart = Instant.now(); + memoryClient .createInteraction( conversationId, llmQuestion, PromptUtil.getPromptTemplate(systemPrompt, userInstructions), answer, GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, - Collections.singletonMap("metadata", jsonArrayToString(searchResults)) + Collections.singletonMap("metadata", jsonArrayToString(searchResults)), + ActionListener.wrap(r -> { + responseListener.onResponse(insertAnswer(response, answer, errorMessage, r)); + log.info("Created a new interaction: {} ({})", r, getDuration(memoryStart)); + }, responseListener::onFailure) ); - log.info("Created a new interaction: {} ({})", interactionId, getDuration(start)); + + } else { + responseListener.onResponse(insertAnswer(response, answer, errorMessage, null)); } + } - return insertAnswer(response, answer, errorMessage, interactionId); - } catch (NullPointerException nullPointerException) { - throw new IllegalArgumentException(IllegalArgumentMessage); - } catch (Exception e) { - throw new OpenSearchException("GenerativeQAResponseProcessor failed in precessing response"); - } - } + @Override + public void onFailure(Exception e) { + responseListener.onFailure(e); + } - long getDuration(Instant start) { - return Duration.between(start, Instant.now()).toMillis(); + private String getError(ChatCompletionOutput output) { + return output.isErrorOccurred() ? output.getErrors().get(0) : null; + } + + private String getAnswer(ChatCompletionOutput output) { + return output.isErrorOccurred() ? null : (String) output.getAnswers().get(0); + } + }); } @Override @@ -211,9 +274,11 @@ public String getType() { return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE; } - private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) { + private long getDuration(Instant start) { + return Duration.between(start, Instant.now()).toMillis(); + } - // TODO return the interaction id in the response. + private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) { return new GenerativeSearchResponse( answer, @@ -240,9 +305,7 @@ private List getSearchResults(SearchResponse response, Integer topN) { for (String contextField : contextFields) { Object context = docSourceMap.get(contextField); if (context == null) { - log.error("Context " + contextField + " not found in search hit " + hits[i]); - // TODO throw a more meaningful error here? - throw new RuntimeException(); + throw new RuntimeException("Context " + contextField + " not found in search hit " + hits[i]); } searchResults.add(context.toString()); } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java index 5db677fe65..70ab33d957 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java @@ -24,6 +24,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; @@ -83,6 +84,33 @@ public String createInteraction( return res.getId(); } + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo, + ActionListener listener + ) { + client + .execute( + CreateInteractionAction.INSTANCE, + new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo), + new ActionListener() { + @Override + public void onResponse(CreateInteractionResponse createInteractionResponse) { + listener.onResponse(createInteractionResponse.getId()); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + } + ); + } + public List getInteractions(String conversationId, int lastN) { Preconditions.checkArgument(lastN > 0, "lastN must be at least 1."); @@ -113,4 +141,23 @@ public List getInteractions(String conversationId, int lastN) { return interactions; } + + public void getInteractions(String conversationId, int lastN, ActionListener> listener) { + client + .execute( + GetInteractionsAction.INSTANCE, + new GetInteractionsRequest(conversationId, lastN, 0), + new ActionListener() { + @Override + public void onResponse(GetInteractionsResponse getInteractionsResponse) { + listener.onResponse(getInteractionsResponse.getInteractions()); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + } + ); + } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java index c49bff254e..21fb95323a 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java @@ -42,7 +42,7 @@ public ActionFuture predict(String modelId, MLInput mlInput) { } @VisibleForTesting - void predict(String modelId, MLInput mlInput, ActionListener listener) { + public void predict(String modelId, MLInput mlInput, ActionListener listener) { validateMLInput(mlInput, true); MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java index 75b6ca0849..f6cdfec816 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java @@ -25,7 +25,7 @@ import java.util.Map; import org.opensearch.client.Client; -import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -75,20 +75,35 @@ protected void setMlClient(MachineLearningInternalClient mlClient) { * @return */ @Override - public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) { + public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener listener) { MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build(); - ActionFuture future = mlClient.predict(this.openSearchModelId, mlInput); - ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000); - - // Response from a remote model - Map dataAsMap = modelOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); - // log.info("dataAsMap: {}", dataAsMap.toString()); - - // TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases. + mlClient.predict(this.openSearchModelId, mlInput, new ActionListener<>() { + @Override + public void onResponse(MLOutput mlOutput) { + // Response from a remote model + Map dataAsMap = ((ModelTensorOutput) mlOutput) + .getMlModelOutputs() + .get(0) + .getMlModelTensors() + .get(0) + .getDataAsMap(); + listener + .onResponse( + buildChatCompletionOutput( + chatCompletionInput.getModelProvider(), + dataAsMap, + chatCompletionInput.getLlmResponseField() + ) + ); + } - return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap, chatCompletionInput.getLlmResponseField()); + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); } protected Map getInputParameters(ChatCompletionInput chatCompletionInput) { diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java index 87ac8fb6dd..1099b1e21f 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java @@ -17,6 +17,8 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; +import org.opensearch.core.action.ActionListener; + /** * Capabilities of large language models, e.g. completion, embeddings, etc. */ @@ -29,5 +31,5 @@ enum ModelProvider { COHERE } - ChatCompletionOutput doChatCompletion(ChatCompletionInput input); + void doChatCompletion(ChatCompletionInput input, ActionListener listener); } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java index 297b287997..a89b5c1731 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java @@ -19,10 +19,11 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor.IllegalArgumentMessage; import java.time.Instant; import java.util.Collections; @@ -40,6 +41,7 @@ import org.opensearch.action.search.SearchResponseSections; import org.opensearch.client.Client; import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.conversation.Interaction; @@ -135,14 +137,17 @@ public void testProcessResponseNoSearchHits() throws Exception { Llm llm = mock(Llm.class); ChatCompletionOutput output = mock(ChatCompletionOutput.class); - when(llm.doChatCompletion(any())).thenReturn(output); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(output); + return null; + }).when(llm).doChatCompletion(any(), any()); when(output.getAnswers()).thenReturn(List.of("foo")); processor.setLlm(llm); ArgumentCaptor captor = ArgumentCaptor.forClass(ChatCompletionInput.class); boolean errorThrown = false; try { - SearchResponse res = processor.processResponse(request, response); + processor.processResponseAsync(request, response, null, ActionListener.wrap(r -> {}, e -> {})); } catch (Exception e) { errorThrown = true; } @@ -161,22 +166,23 @@ public void testProcessResponse() throws Exception { ).create(null, "tag", "desc", true, config, null); ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); - when(memoryClient.getInteractions(any(), anyInt())) - .thenReturn( - List - .of( - new Interaction( - "0", - Instant.now(), - "1", - "question", - "", - "answer", - "foo", - Collections.singletonMap("meta data", "some meta") - ) - ) + List chatHistory = List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) ); + doAnswer(invocation -> { + ((ActionListener>) invocation.getArguments()[2]).onResponse(chatHistory); + return null; + }).when(memoryClient).getInteractions(any(), anyInt(), any()); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); @@ -217,20 +223,28 @@ public void testProcessResponse() throws Exception { Llm llm = mock(Llm.class); ChatCompletionOutput output = mock(ChatCompletionOutput.class); - when(llm.doChatCompletion(any())).thenReturn(output); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[1]).onResponse(output); + return null; + }).when(llm).doChatCompletion(any(), any()); when(output.getAnswers()).thenReturn(List.of("foo")); processor.setLlm(llm); ArgumentCaptor captor = ArgumentCaptor.forClass(ChatCompletionInput.class); - SearchResponse res = processor.processResponse(request, response); - verify(llm).doChatCompletion(captor.capture()); + processor + .processResponseAsync( + request, + response, + null, + ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> {}) + ); + verify(llm).doChatCompletion(captor.capture(), any()); ChatCompletionInput input = captor.getValue(); assertTrue(input instanceof ChatCompletionInput); - List passages = ((ChatCompletionInput) input).getContexts(); + List passages = input.getContexts(); assertEquals("passage0", passages.get(0)); assertEquals("passage1", passages.get(1)); assertEquals(numHits, passages.size()); - assertTrue(res instanceof GenerativeSearchResponse); } public void testProcessResponseWithErrorFromLlm() throws Exception { @@ -245,22 +259,23 @@ public void testProcessResponseWithErrorFromLlm() throws Exception { ).create(null, "tag", "desc", true, config, null); ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); - when(memoryClient.getInteractions(any(), anyInt())) - .thenReturn( - List - .of( - new Interaction( - "0", - Instant.now(), - "1", - "question", - "", - "answer", - "foo", - Collections.singletonMap("meta data", "some meta") - ) - ) + List chatHistory = List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) ); + doAnswer(invocation -> { + ((ActionListener>) invocation.getArguments()[2]).onResponse(chatHistory); + return null; + }).when(memoryClient).getInteractions(any(), anyInt(), any()); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); @@ -301,21 +316,29 @@ public void testProcessResponseWithErrorFromLlm() throws Exception { Llm llm = mock(Llm.class); ChatCompletionOutput output = mock(ChatCompletionOutput.class); - when(llm.doChatCompletion(any())).thenReturn(output); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[1]).onResponse(output); + return null; + }).when(llm).doChatCompletion(any(), any()); when(output.isErrorOccurred()).thenReturn(true); when(output.getErrors()).thenReturn(List.of("something bad has occurred.")); processor.setLlm(llm); ArgumentCaptor captor = ArgumentCaptor.forClass(ChatCompletionInput.class); - SearchResponse res = processor.processResponse(request, response); - verify(llm).doChatCompletion(captor.capture()); + processor + .processResponseAsync( + request, + response, + null, + ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> {}) + ); + verify(llm).doChatCompletion(captor.capture(), any()); ChatCompletionInput input = captor.getValue(); assertTrue(input instanceof ChatCompletionInput); - List passages = ((ChatCompletionInput) input).getContexts(); + List passages = input.getContexts(); assertEquals("passage0", passages.get(0)); assertEquals("passage1", passages.get(1)); assertEquals(numHits, passages.size()); - assertTrue(res instanceof GenerativeSearchResponse); } public void testProcessResponseSmallerContextSize() throws Exception { @@ -330,22 +353,23 @@ public void testProcessResponseSmallerContextSize() throws Exception { ).create(null, "tag", "desc", true, config, null); ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); - when(memoryClient.getInteractions(any(), anyInt())) - .thenReturn( - List - .of( - new Interaction( - "0", - Instant.now(), - "1", - "question", - "", - "answer", - "foo", - Collections.singletonMap("meta data", "some meta") - ) - ) + List chatHistory = List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) ); + doAnswer(invocation -> { + ((ActionListener>) invocation.getArguments()[2]).onResponse(chatHistory); + return null; + }).when(memoryClient).getInteractions(any(), anyInt(), any()); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); @@ -387,20 +411,28 @@ public void testProcessResponseSmallerContextSize() throws Exception { Llm llm = mock(Llm.class); ChatCompletionOutput output = mock(ChatCompletionOutput.class); - when(llm.doChatCompletion(any())).thenReturn(output); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[1]).onResponse(output); + return null; + }).when(llm).doChatCompletion(any(), any()); when(output.getAnswers()).thenReturn(List.of("foo")); processor.setLlm(llm); ArgumentCaptor captor = ArgumentCaptor.forClass(ChatCompletionInput.class); - SearchResponse res = processor.processResponse(request, response); - verify(llm).doChatCompletion(captor.capture()); + processor + .processResponseAsync( + request, + response, + null, + ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> {}) + ); + verify(llm).doChatCompletion(captor.capture(), any()); ChatCompletionInput input = captor.getValue(); assertTrue(input instanceof ChatCompletionInput); List passages = ((ChatCompletionInput) input).getContexts(); assertEquals("passage0", passages.get(0)); assertEquals("passage1", passages.get(1)); assertEquals(contextSize, passages.size()); - assertTrue(res instanceof GenerativeSearchResponse); } public void testProcessResponseMissingContextField() throws Exception { @@ -415,22 +447,23 @@ public void testProcessResponseMissingContextField() throws Exception { ).create(null, "tag", "desc", true, config, null); ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); - when(memoryClient.getInteractions(any(), anyInt())) - .thenReturn( - List - .of( - new Interaction( - "0", - Instant.now(), - "1", - "question", - "", - "answer", - "foo", - Collections.singletonMap("meta data", "some meta") - ) - ) + List chatHistory = List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) ); + doAnswer(invocation -> { + ((ActionListener>) invocation.getArguments()[2]).onResponse(chatHistory); + return null; + }).when(memoryClient).getInteractions(any(), anyInt(), any()); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); @@ -471,14 +504,17 @@ public void testProcessResponseMissingContextField() throws Exception { Llm llm = mock(Llm.class); ChatCompletionOutput output = mock(ChatCompletionOutput.class); - when(llm.doChatCompletion(any())).thenReturn(output); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[1]).onResponse(output); + return null; + }).when(llm).doChatCompletion(any(), any()); when(output.getAnswers()).thenReturn(List.of("foo")); processor.setLlm(llm); boolean exceptionThrown = false; try { - SearchResponse res = processor.processResponse(request, response); + processor.processResponseAsync(request, response, null, ActionListener.wrap(r -> {}, e -> {})); } catch (Exception e) { exceptionThrown = true; } @@ -527,7 +563,8 @@ public void testProcessorFeatureOffOnOff() throws Exception { featureEnabled001 = false; boolean secondExceptionThrown = false; try { - processor.processResponse(mock(SearchRequest.class), mock(SearchResponse.class)); + processor + .processResponseAsync(mock(SearchRequest.class), mock(SearchResponse.class), null, ActionListener.wrap(r -> {}, e -> {})); } catch (MLException e) { assertEquals(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG, e.getMessage()); secondExceptionThrown = true; @@ -536,8 +573,8 @@ public void testProcessorFeatureOffOnOff() throws Exception { } public void testProcessResponseNullValueInteractions() throws Exception { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage(IllegalArgumentMessage); + // exceptionRule.expect(IllegalArgumentException.class); + // exceptionRule.expectMessage("Null Pointer in Interactions"); Client client = mock(Client.class); Map config = new HashMap<>(); @@ -550,8 +587,11 @@ public void testProcessResponseNullValueInteractions() throws Exception { ).create(null, "tag", "desc", true, config, null); ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); - when(memoryClient.getInteractions(any(), anyInt())) - .thenReturn(List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null))); + List chatHistory = List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null)); + doAnswer(invocation -> { + ((ActionListener>) invocation.getArguments()[2]).onResponse(chatHistory); + return null; + }).when(memoryClient).getInteractions(any(), anyInt(), any()); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); @@ -592,10 +632,18 @@ public void testProcessResponseNullValueInteractions() throws Exception { SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); Llm llm = mock(Llm.class); - when(llm.doChatCompletion(any())).thenThrow(new NullPointerException("Null Pointer in Interactions")); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[1]) + .onFailure(new NullPointerException("Null Pointer in Interactions")); + return null; + }).when(llm).doChatCompletion(any(), any()); + // when(llm.doChatCompletion(any())).thenThrow(new NullPointerException("Null Pointer in Interactions")); processor.setLlm(llm); - SearchResponse res = processor.processResponse(request, response); + processor.processResponseAsync(request, response, null, ActionListener.wrap(r -> {}, e -> { + assertTrue(e instanceof NullPointerException); + // throw new IllegalArgumentException(e.getMessage()); + })); } public void testProcessResponseIllegalArgument() throws Exception { @@ -613,8 +661,23 @@ public void testProcessResponseIllegalArgument() throws Exception { ).create(null, "tag", "desc", true, config, null); ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); - when(memoryClient.getInteractions(any(), anyInt())) - .thenReturn(List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null))); + List chatHistory = List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) + ); + doAnswer(invocation -> { + ((ActionListener>) invocation.getArguments()[2]).onResponse(chatHistory); + return null; + }).when(memoryClient).getInteractions(any(), anyInt(), any()); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); @@ -655,16 +718,18 @@ public void testProcessResponseIllegalArgument() throws Exception { SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); Llm llm = mock(Llm.class); - // when(llm.doChatCompletion(any())).thenThrow(new NullPointerException("Null Pointer in Interactions")); processor.setLlm(llm); - SearchResponse res = processor.processResponse(request, response); + processor + .processResponseAsync( + request, + response, + null, + ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> {}) + ); } public void testProcessResponseOpenSearchException() throws Exception { - exceptionRule.expect(OpenSearchException.class); - exceptionRule.expectMessage("GenerativeQAResponseProcessor failed in precessing response"); - Client client = mock(Client.class); Map config = new HashMap<>(); config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); @@ -676,8 +741,23 @@ public void testProcessResponseOpenSearchException() throws Exception { ).create(null, "tag", "desc", true, config, null); ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); - when(memoryClient.getInteractions(any(), anyInt())) - .thenReturn(List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null))); + List chatHistory = List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) + ); + doAnswer(invocation -> { + ((ActionListener>) invocation.getArguments()[2]).onResponse(chatHistory); + return null; + }).when(memoryClient).getInteractions(any(), anyInt(), any()); processor.setMemoryClient(memoryClient); SearchRequest request = new SearchRequest(); @@ -718,9 +798,21 @@ public void testProcessResponseOpenSearchException() throws Exception { SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); Llm llm = mock(Llm.class); - when(llm.doChatCompletion(any())).thenThrow(new RuntimeException()); + // doAnswer(invocation -> { + // ((ActionListener) invocation.getArguments()[1]).onFailure(new RuntimeException()); + // return null; + doThrow(new OpenSearchException("")).when(llm).doChatCompletion(any(), any()); + // when(llm.doChatCompletion(any())).thenThrow(new RuntimeException()); processor.setLlm(llm); - SearchResponse res = processor.processResponse(request, response); + processor + .processResponseAsync( + request, + response, + null, + ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> { + assertTrue(e instanceof OpenSearchException); + }) + ); } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java index 5a3978539c..2dc06366f8 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -36,6 +36,7 @@ import org.mockito.Mock; import org.opensearch.client.Client; import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -121,11 +122,24 @@ public void testChatCompletionApi() throws Exception { Llm.ModelProvider.OPENAI, null ); - ChatCompletionOutput output = connector.doChatCompletion(input); - verify(mlClient, times(1)).predict(any(), captor.capture()); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertEquals("answer", output.getAnswers().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); MLInput mlInput = captor.getValue(); assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - assertEquals("answer", (String) output.getAnswers().get(0)); } public void testChatCompletionApiForBedrock() throws Exception { @@ -152,11 +166,24 @@ public void testChatCompletionApiForBedrock() throws Exception { Llm.ModelProvider.BEDROCK, null ); - ChatCompletionOutput output = connector.doChatCompletion(input); - verify(mlClient, times(1)).predict(any(), captor.capture()); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertEquals("answer", output.getAnswers().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); MLInput mlInput = captor.getValue(); assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - assertEquals("answer", (String) output.getAnswers().get(0)); } public void testChatCompletionApiForCohere() throws Exception { @@ -183,11 +210,24 @@ public void testChatCompletionApiForCohere() throws Exception { Llm.ModelProvider.COHERE, null ); - ChatCompletionOutput output = connector.doChatCompletion(input); - verify(mlClient, times(1)).predict(any(), captor.capture()); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertEquals("answer", output.getAnswers().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); MLInput mlInput = captor.getValue(); assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - assertEquals("answer", (String) output.getAnswers().get(0)); } public void testChatCompletionApiForCohereWithError() throws Exception { @@ -215,12 +255,25 @@ public void testChatCompletionApiForCohereWithError() throws Exception { Llm.ModelProvider.COHERE, null ); - ChatCompletionOutput output = connector.doChatCompletion(input); - verify(mlClient, times(1)).predict(any(), captor.capture()); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertTrue(output.isErrorOccurred()); + assertEquals(errorMessage, (String) output.getErrors().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); MLInput mlInput = captor.getValue(); assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - assertTrue(output.isErrorOccurred()); - assertEquals(errorMessage, (String) output.getErrors().get(0)); } public void testChatCompletionApiForFoo() throws Exception { @@ -249,11 +302,24 @@ public void testChatCompletionApiForFoo() throws Exception { null, llmRespondField ); - ChatCompletionOutput output = connector.doChatCompletion(input); - verify(mlClient, times(1)).predict(any(), captor.capture()); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertEquals("answer", output.getAnswers().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); MLInput mlInput = captor.getValue(); assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - assertEquals("answer", (String) output.getAnswers().get(0)); } public void testChatCompletionApiForFooWithError() throws Exception { @@ -283,15 +349,28 @@ public void testChatCompletionApiForFooWithError() throws Exception { null, llmRespondField ); - ChatCompletionOutput output = connector.doChatCompletion(input); - verify(mlClient, times(1)).predict(any(), captor.capture()); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertTrue(output.isErrorOccurred()); + assertEquals(errorMessage, (String) output.getErrors().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); MLInput mlInput = captor.getValue(); assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - assertTrue(output.isErrorOccurred()); - assertEquals(errorMessage, (String) output.getErrors().get(0)); } - public void testChatCompletionApiForFooWithErrorUnknowMessageField() throws Exception { + public void testChatCompletionApiForFooWithErrorUnknownMessageField() throws Exception { MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); @@ -318,15 +397,28 @@ public void testChatCompletionApiForFooWithErrorUnknowMessageField() throws Exce null, llmRespondField ); - ChatCompletionOutput output = connector.doChatCompletion(input); - verify(mlClient, times(1)).predict(any(), captor.capture()); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertTrue(output.isErrorOccurred()); + assertEquals("Unknown error or response.", output.getErrors().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); MLInput mlInput = captor.getValue(); assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - assertTrue(output.isErrorOccurred()); - assertEquals("Unknown error or response.", (String) output.getErrors().get(0)); } - public void testChatCompletionApiForFooWithErrorUnknowErrorField() throws Exception { + public void testChatCompletionApiForFooWithErrorUnknownErrorField() throws Exception { MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); @@ -353,12 +445,25 @@ public void testChatCompletionApiForFooWithErrorUnknowErrorField() throws Except null, llmRespondField ); - ChatCompletionOutput output = connector.doChatCompletion(input); - verify(mlClient, times(1)).predict(any(), captor.capture()); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertTrue(output.isErrorOccurred()); + assertEquals("Unknown error or response.", output.getErrors().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); MLInput mlInput = captor.getValue(); assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - assertTrue(output.isErrorOccurred()); - assertEquals("Unknown error or response.", (String) output.getErrors().get(0)); } public void testChatCompletionThrowingError() throws Exception { @@ -386,12 +491,26 @@ public void testChatCompletionThrowingError() throws Exception { Llm.ModelProvider.OPENAI, null ); - ChatCompletionOutput output = connector.doChatCompletion(input); - verify(mlClient, times(1)).predict(any(), captor.capture()); + + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertTrue(output.isErrorOccurred()); + assertEquals(errorMessage, output.getErrors().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); MLInput mlInput = captor.getValue(); assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - assertTrue(output.isErrorOccurred()); - assertEquals(errorMessage, (String) output.getErrors().get(0)); } public void testChatCompletionBedrockThrowingError() throws Exception { @@ -419,12 +538,25 @@ public void testChatCompletionBedrockThrowingError() throws Exception { Llm.ModelProvider.BEDROCK, null ); - ChatCompletionOutput output = connector.doChatCompletion(input); - verify(mlClient, times(1)).predict(any(), captor.capture()); + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertTrue(output.isErrorOccurred()); + assertEquals(errorMessage, output.getErrors().get(0)); + } + + @Override + public void onFailure(Exception e) { + + } + }); + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); MLInput mlInput = captor.getValue(); assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - assertTrue(output.isErrorOccurred()); - assertEquals(errorMessage, (String) output.getErrors().get(0)); } public void testIllegalArgument1() { @@ -455,7 +587,7 @@ public void testIllegalArgument1() { null, null ); - ChatCompletionOutput output = connector.doChatCompletion(input); + connector.doChatCompletion(input, ActionListener.wrap(r -> {}, e -> {})); } public void testIllegalArgument2() {