Skip to content

Commit

Permalink
Fix: Gracefully handle error when generative_qa_parameters is not pro…
Browse files Browse the repository at this point in the history
…vided (#3100)

* fix: gracefully handle error when generative_qa_parameters is not provided

Signed-off-by: Pavan Yekbote <[email protected]>

* fix: spotless apply

Signed-off-by: Pavan Yekbote <[email protected]>

* docs: adding documentation link to error message

Signed-off-by: Pavan Yekbote <[email protected]>

* tests: adding UT to test null params

Signed-off-by: Pavan Yekbote <[email protected]>

---------

Signed-off-by: Pavan Yekbote <[email protected]>
(cherry picked from commit 0f7481e)
  • Loading branch information
pyek-bot authored and github-actions[bot] committed Oct 15, 2024
1 parent 5e2863f commit b6736dc
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,8 @@ public class GenerativeQAProcessorConstants {
.boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final String FEATURE_NOT_ENABLED_ERROR_MSG = RAG_PIPELINE_FEATURE_ENABLED.getKey() + " is not enabled.";

public static final String RAG_NULL_GEN_QA_PARAMS_ERROR_MSG = "generative_qa_parameters not found."
+ " Please provide ext.generative_qa_parameters to proceed."
+ " For more info, refer: https://opensearch.org/docs/latest/search-plugins/conversational-search/#step-6-use-the-pipeline-for-rag";
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.opensearch.searchpipelines.questionanswering.generative;

import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG;

import java.time.Duration;
import java.time.Instant;
Expand Down Expand Up @@ -126,6 +127,9 @@ public void processResponseAsync(
}

GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
if (params == null) {
throw new IllegalArgumentException(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG);
}

Integer t = params.getTimeout();
if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG;

import java.time.Instant;
import java.util.Collections;
Expand Down Expand Up @@ -646,6 +647,77 @@ public void testProcessResponseNullValueInteractions() throws Exception {
}));
}

public void testProcessResponseIllegalArgumentForNullParams() throws Exception {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG);

Client client = mock(Client.class);
Map<String, Object> config = new HashMap<>();
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model");
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text"));

GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(
client,
alwaysOn
).create(null, "tag", "desc", true, config, null);

ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class);
List<Interaction> chatHistory = List
.of(
new Interaction(
"0",
Instant.now(),
"1",
"question",
"",
"answer",
"foo",
Collections.singletonMap("meta data", "some meta")
)
);
doAnswer(invocation -> {
((ActionListener<List<Interaction>>) invocation.getArguments()[2]).onResponse(chatHistory);
return null;
}).when(memoryClient).getInteractions(any(), anyInt(), any());
processor.setMemoryClient(memoryClient);

SearchRequest request = new SearchRequest();
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(null);
request.source(sourceBuilder);
sourceBuilder.ext(List.of(extBuilder));

int numHits = 10;
SearchHit[] hitsArray = new SearchHit[numHits];
for (int i = 0; i < numHits; i++) {
XContentBuilder sourceContent = JsonXContent
.contentBuilder()
.startObject()
.field("_id", String.valueOf(i))
.field("text", "passage" + i)
.field("title", "This is the title for document " + i)
.endObject();
hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of());
hitsArray[i].sourceRef(BytesReference.bytes(sourceContent));
}

SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f);
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null);

Llm llm = mock(Llm.class);
processor.setLlm(llm);

processor
.processResponseAsync(
request,
response,
null,
ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> {})
);
}

public void testProcessResponseIllegalArgument() throws Exception {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("llm_model cannot be null.");
Expand Down

0 comments on commit b6736dc

Please sign in to comment.