diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index 8fb58c1b9a..7cf490c664 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -359,7 +359,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " \"ext\": {\n" + " \"generative_qa_parameters\": {\n" + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" + " \"system_prompt\": \"%s\",\n" + " \"user_instructions\": \"%s\",\n" + " \"context_size\": %d,\n" @@ -378,8 +377,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " \"ext\": {\n" + " \"generative_qa_parameters\": {\n" + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" - // + " \"system_prompt\": \"%s\",\n" + " \"user_instructions\": \"%s\",\n" + " \"context_size\": %d,\n" + " \"message_size\": %d,\n" @@ -723,8 +720,12 @@ public void testBM25WithBedrock() throws Exception { public void testBM25WithBedrockConverse() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null) { + System.out.println("Skipping testBM25WithBedrockConverse because AWS_ACCESS_KEY_ID is null"); return; } + + System.out.println("Running testBM25WithBedrockConverse"); + Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -775,8 +776,11 @@ public void testBM25WithBedrockConverse() throws Exception { public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null) { + System.out.println("Skipping testBM25WithBedrockConverseUsingLlmMessages because AWS_ACCESS_KEY_ID is null"); return; } + System.out.println("Running testBM25WithBedrockConverseUsingLlmMessages"); + Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -835,8 +839,11 @@ public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception { public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null) { + System.out.println("Skipping testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat because AWS_ACCESS_KEY_ID is null"); return; } + + System.out.println("Running testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat"); Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -894,8 +901,11 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws public void testBM25WithOpenAIWithConversation() throws Exception { // Skip test if key is null if (OPENAI_KEY == null) { + System.out.println("Skipping testBM25WithOpenAIWithConversation because OPENAI_KEY is null"); return; } + System.out.println("Running testBM25WithOpenAIWithConversation"); + Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -951,8 +961,11 @@ public void testBM25WithOpenAIWithConversation() throws Exception { public void testBM25WithOpenAIWithConversationAndImage() throws Exception { // Skip test if key is null if (OPENAI_KEY == null) { + System.out.println("Skipping testBM25WithOpenAIWithConversationAndImage because OPENAI_KEY is null"); return; } + System.out.println("Running testBM25WithOpenAIWithConversationAndImage"); + Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -1245,7 +1258,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.source, requestParameters.match, requestParameters.llmModel, - requestParameters.llmQuestion, requestParameters.systemPrompt, requestParameters.userInstructions, requestParameters.contextSize, @@ -1268,8 +1280,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.source, requestParameters.match, requestParameters.llmModel, - requestParameters.llmQuestion, - // requestParameters.systemPrompt, requestParameters.userInstructions, requestParameters.contextSize, requestParameters.interactionSize, @@ -1309,7 +1319,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.source, requestParameters.match, requestParameters.llmModel, - requestParameters.llmQuestion, requestParameters.systemPrompt, requestParameters.userInstructions, requestParameters.contextSize, diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index d5ec0e47c1..6ff89093b3 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -167,9 +167,11 @@ public GenerativeQAParameters( this.conversationId = conversationId; this.llmModel = llmModel; - // TODO: keep this requirement until we can extract the question from the query or from the request processor parameters - // for question rewriting. - Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided."); + Preconditions + .checkArgument( + !(Strings.isNullOrEmpty(llmQuestion) && (llmMessages == null || llmMessages.isEmpty())), + "At least one of " + LLM_QUESTION + " or " + LLM_MESSAGES_FIELD + " must be provided." + ); this.llmQuestion = llmQuestion; this.systemPrompt = systemPrompt; this.userInstructions = userInstructions; @@ -185,7 +187,7 @@ public GenerativeQAParameters( public GenerativeQAParameters(StreamInput input) throws IOException { this.conversationId = input.readOptionalString(); this.llmModel = input.readOptionalString(); - this.llmQuestion = input.readString(); + this.llmQuestion = input.readOptionalString(); this.systemPrompt = input.readOptionalString(); this.userInstructions = input.readOptionalString(); this.contextSize = input.readInt(); @@ -246,9 +248,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(conversationId); out.writeOptionalString(llmModel); - - Preconditions.checkNotNull(llmQuestion, "llm_question must not be null."); - out.writeString(llmQuestion); + out.writeOptionalString(llmQuestion); out.writeOptionalString(systemPrompt); out.writeOptionalString(userInstructions); out.writeInt(contextSize); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 2772884f11..8a5ade0072 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -121,8 +121,7 @@ public void testMiscMethods() throws IOException { StreamOutput so = mock(StreamOutput.class); builder1.writeTo(so); - verify(so, times(5)).writeOptionalString(any()); - verify(so, times(1)).writeString(any()); + verify(so, times(6)).writeOptionalString(any()); } public void testParse() throws IOException {