Skip to content

Commit

Permalink
Allow llmQuestion to be optional when llmMessages is used. (Issue ope…
Browse files Browse the repository at this point in the history
…nsearch-project#3067)

Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee committed Oct 8, 2024
1 parent 74c211e commit 138def6
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
+ " \"ext\": {\n"
+ " \"generative_qa_parameters\": {\n"
+ " \"llm_model\": \"%s\",\n"
+ " \"llm_question\": \"%s\",\n"
// + " \"llm_question\": \"%s\",\n"
+ " \"system_prompt\": \"%s\",\n"
+ " \"user_instructions\": \"%s\",\n"
+ " \"context_size\": %d,\n"
Expand All @@ -378,7 +378,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
+ " \"ext\": {\n"
+ " \"generative_qa_parameters\": {\n"
+ " \"llm_model\": \"%s\",\n"
+ " \"llm_question\": \"%s\",\n"
// + " \"llm_question\": \"%s\",\n"
// + " \"system_prompt\": \"%s\",\n"
+ " \"user_instructions\": \"%s\",\n"
+ " \"context_size\": %d,\n"
Expand Down Expand Up @@ -723,8 +723,12 @@ public void testBM25WithBedrock() throws Exception {
public void testBM25WithBedrockConverse() throws Exception {
// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null) {
System.out.println("Skipping testBM25WithBedrockConverse because AWS_ACCESS_KEY_ID is null");
return;
}

System.out.println("Running testBM25WithBedrockConverse");

Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
Expand Down Expand Up @@ -775,8 +779,11 @@ public void testBM25WithBedrockConverse() throws Exception {
public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null) {
System.out.println("Skipping testBM25WithBedrockConverseUsingLlmMessages because AWS_ACCESS_KEY_ID is null");
return;
}
System.out.println("Running testBM25WithBedrockConverseUsingLlmMessages");

Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
Expand Down Expand Up @@ -835,8 +842,11 @@ public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws Exception {
// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null) {
System.out.println("Skipping testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat because AWS_ACCESS_KEY_ID is null");
return;
}

System.out.println("Running testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat");
Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
Expand Down Expand Up @@ -894,8 +904,11 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws
public void testBM25WithOpenAIWithConversation() throws Exception {
// Skip test if key is null
if (OPENAI_KEY == null) {
System.out.println("Skipping testBM25WithOpenAIWithConversation because OPENAI_KEY is null");
return;
}
System.out.println("Running testBM25WithOpenAIWithConversation");

Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
Expand Down Expand Up @@ -951,8 +964,11 @@ public void testBM25WithOpenAIWithConversation() throws Exception {
public void testBM25WithOpenAIWithConversationAndImage() throws Exception {
// Skip test if key is null
if (OPENAI_KEY == null) {
System.out.println("Skipping testBM25WithOpenAIWithConversationAndImage because OPENAI_KEY is null");
return;
}
System.out.println("Running testBM25WithOpenAIWithConversationAndImage");

Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
Expand Down Expand Up @@ -1245,7 +1261,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.source,
requestParameters.match,
requestParameters.llmModel,
requestParameters.llmQuestion,
requestParameters.systemPrompt,
requestParameters.userInstructions,
requestParameters.contextSize,
Expand All @@ -1268,8 +1283,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.source,
requestParameters.match,
requestParameters.llmModel,
requestParameters.llmQuestion,
// requestParameters.systemPrompt,
requestParameters.userInstructions,
requestParameters.contextSize,
requestParameters.interactionSize,
Expand Down Expand Up @@ -1309,7 +1322,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.source,
requestParameters.match,
requestParameters.llmModel,
requestParameters.llmQuestion,
requestParameters.systemPrompt,
requestParameters.userInstructions,
requestParameters.contextSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,11 @@ public GenerativeQAParameters(
this.conversationId = conversationId;
this.llmModel = llmModel;

// TODO: keep this requirement until we can extract the question from the query or from the request processor parameters
// for question rewriting.
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided.");
Preconditions
.checkArgument(
!(Strings.isNullOrEmpty(llmQuestion) && (llmMessages == null || llmMessages.isEmpty())),
"At least one of " + LLM_QUESTION + " or " + LLM_MESSAGES_FIELD + " must be provided."
);
this.llmQuestion = llmQuestion;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
Expand All @@ -185,7 +187,7 @@ public GenerativeQAParameters(
public GenerativeQAParameters(StreamInput input) throws IOException {
this.conversationId = input.readOptionalString();
this.llmModel = input.readOptionalString();
this.llmQuestion = input.readString();
this.llmQuestion = input.readOptionalString();
this.systemPrompt = input.readOptionalString();
this.userInstructions = input.readOptionalString();
this.contextSize = input.readInt();
Expand Down Expand Up @@ -246,9 +248,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(conversationId);
out.writeOptionalString(llmModel);

Preconditions.checkNotNull(llmQuestion, "llm_question must not be null.");
out.writeString(llmQuestion);
out.writeOptionalString(llmQuestion);
out.writeOptionalString(systemPrompt);
out.writeOptionalString(userInstructions);
out.writeInt(contextSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ public void testMiscMethods() throws IOException {

StreamOutput so = mock(StreamOutput.class);
builder1.writeTo(so);
verify(so, times(5)).writeOptionalString(any());
verify(so, times(1)).writeString(any());
verify(so, times(6)).writeOptionalString(any());
}

public void testParse() throws IOException {
Expand Down

0 comments on commit 138def6

Please sign in to comment.