From bc6ec93ac5090a1135da0a8e74f294c433152c6e Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Fri, 11 Oct 2024 10:24:47 -0700 Subject: [PATCH] [Backport 2.12] Bug Fix: Fix for rag processor throwing NPE (#3089) * [Backport 2.17] Bug Fix: Fix for rag processor throwing NPE when optional parameters are not provided (#3066) (#3076) (cherry picked from commit 8b5b38e4b28e182a0cdfbf1c55c6ef00e6663a57) Co-authored-by: Pavan Yekbote * fix: spotless changes Signed-off-by: Pavan Yekbote --------- Signed-off-by: Pavan Yekbote Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> --- .../ext/GenerativeQAParameters.java | 113 +++++++++++++----- .../ext/GenerativeQAParamExtBuilderTests.java | 42 +++++-- .../ext/GenerativeQAParametersTests.java | 13 +- 3 files changed, 131 insertions(+), 37 deletions(-) diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index 2710b26a57..ba2bf0fe86 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -17,21 +17,22 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.ext; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.util.Objects; -import org.opensearch.core.ParseField; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import com.google.common.base.Preconditions; +import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; @@ -44,44 +45,32 @@ @NoArgsConstructor public class GenerativeQAParameters implements Writeable, ToXContentObject { - private static final ObjectParser PARSER; - // Optional parameter; if provided, conversational memory will be used for RAG // and the current interaction will be saved in the conversation referenced by this id. - private static final ParseField CONVERSATION_ID = new ParseField("conversation_id"); + private static final String CONVERSATION_ID = "memory_id"; // Optional parameter; if an LLM model is not set at the search pipeline level, one must be // provided at the search request level. - private static final ParseField LLM_MODEL = new ParseField("llm_model"); + private static final String LLM_MODEL = "llm_model"; // Required parameter; this is sent to LLMs as part of the user prompt. // TODO support question rewriting when chat history is not used (conversation_id is not provided). - private static final ParseField LLM_QUESTION = new ParseField("llm_question"); + private static final String LLM_QUESTION = "llm_question"; // Optional parameter; this parameter controls the number of search results ("contexts") to // include in the user prompt. - private static final ParseField CONTEXT_SIZE = new ParseField("context_size"); + private static final String CONTEXT_SIZE = "context_size"; // Optional parameter; this parameter controls the number of the interactions to include // in the user prompt. - private static final ParseField INTERACTION_SIZE = new ParseField("interaction_size"); + private static final String INTERACTION_SIZE = "message_size"; // Optional parameter; this parameter controls how long the search pipeline waits for a response // from a remote inference endpoint before timing out the request. - private static final ParseField TIMEOUT = new ParseField("timeout"); + private static final String TIMEOUT = "timeout"; public static final int SIZE_NULL_VALUE = -1; - static { - PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new); - PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID); - PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL); - PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION); - PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE); - PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE); - PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT); - } - @Setter @Getter private String conversationId; @@ -106,6 +95,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { @Getter private Integer timeout; + @Builder public GenerativeQAParameters( String conversationId, String llmModel, @@ -119,7 +109,7 @@ public GenerativeQAParameters( // TODO: keep this requirement until we can extract the question from the query or from the request processor parameters // for question rewriting. - Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided."); + Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided."); this.llmQuestion = llmQuestion; this.contextSize = (contextSize == null) ? SIZE_NULL_VALUE : contextSize; this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize; @@ -137,13 +127,33 @@ public GenerativeQAParameters(StreamInput input) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { - return xContentBuilder - .field(CONVERSATION_ID.getPreferredName(), this.conversationId) - .field(LLM_MODEL.getPreferredName(), this.llmModel) - .field(LLM_QUESTION.getPreferredName(), this.llmQuestion) - .field(CONTEXT_SIZE.getPreferredName(), this.contextSize) - .field(INTERACTION_SIZE.getPreferredName(), this.interactionSize) - .field(TIMEOUT.getPreferredName(), this.timeout); + xContentBuilder.startObject(); + if (this.conversationId != null) { + xContentBuilder.field(CONVERSATION_ID, this.conversationId); + } + + if (this.llmModel != null) { + xContentBuilder.field(LLM_MODEL, this.llmModel); + } + + if (this.llmQuestion != null) { + xContentBuilder.field(LLM_QUESTION, this.llmQuestion); + } + + if (this.contextSize != null) { + xContentBuilder.field(CONTEXT_SIZE, this.contextSize); + } + + if (this.interactionSize != null) { + xContentBuilder.field(INTERACTION_SIZE, this.interactionSize); + } + + if (this.timeout != null) { + xContentBuilder.field(TIMEOUT, this.timeout); + } + + xContentBuilder.endObject(); + return xContentBuilder; } @Override @@ -159,7 +169,52 @@ public void writeTo(StreamOutput out) throws IOException { } public static GenerativeQAParameters parse(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); + String conversationId = null; + String llmModel = null; + String llmQuestion = null; + Integer contextSize = null; + Integer interactionSize = null; + Integer timeout = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String field = parser.currentName(); + parser.nextToken(); + + switch (field) { + case CONVERSATION_ID: + conversationId = parser.text(); + break; + case LLM_MODEL: + llmModel = parser.text(); + break; + case LLM_QUESTION: + llmQuestion = parser.text(); + break; + case CONTEXT_SIZE: + contextSize = parser.intValue(); + break; + case INTERACTION_SIZE: + interactionSize = parser.intValue(); + break; + case TIMEOUT: + timeout = parser.intValue(); + break; + default: + parser.skipChildren(); + break; + } + } + + return GenerativeQAParameters + .builder() + .conversationId(conversationId) + .llmModel(llmModel) + .llmQuestion(llmQuestion) + .contextSize(contextSize) + .interactionSize(interactionSize) + .timeout(timeout) + .build(); } @Override diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 5aeb1e804f..aabdc35cf1 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -21,18 +21,23 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import java.io.EOFException; import java.io.IOException; +import java.util.Collections; +import org.junit.Assert; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentHelper; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; import org.opensearch.test.OpenSearchTestCase; public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase { @@ -97,21 +102,38 @@ public void testMiscMethods() throws IOException { } public void testParse() throws IOException { - XContentParser xcParser = mock(XContentParser.class); - when(xcParser.nextToken()).thenReturn(XContentParser.Token.START_OBJECT).thenReturn(XContentParser.Token.END_OBJECT); - GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(xcParser); + String requiredJsonStr = "{\"llm_question\":\"this is test llm question\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + requiredJsonStr + ); + + parser.nextToken(); + GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(parser); assertNotNull(builder); assertNotNull(builder.getParams()); + GenerativeQAParameters params = builder.getParams(); + Assert.assertEquals("this is test llm question", params.getLlmQuestion()); } public void testXContentRoundTrip() throws IOException { GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); + XContentType xContentType = randomFrom(XContentType.values()); - BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true); + XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()); + builder = extBuilder.toXContent(builder, EMPTY_PARAMS); + BytesReference serialized = BytesReference.bytes(builder); + XContentParser parser = createParser(xContentType.xContent(), serialized); + parser.nextToken(); GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser); + assertEquals(extBuilder, deserialized); GenerativeQAParameters parameters = deserialized.getParams(); assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getContextSize()); @@ -123,10 +145,16 @@ public void testXContentRoundTripAllValues() throws IOException { GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); + XContentType xContentType = randomFrom(XContentType.values()); - BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true); + XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()); + builder = extBuilder.toXContent(builder, EMPTY_PARAMS); + BytesReference serialized = BytesReference.bytes(builder); + XContentParser parser = createParser(xContentType.xContent(), serialized); + parser.nextToken(); GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser); + assertEquals(extBuilder, deserialized); } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java index 600b1c7a19..713339da71 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -148,7 +148,18 @@ public void testToXConent() throws IOException { assertNotNull(parameters.toXContent(builder, null)); } - public void testToXConentAllOptionalParameters() throws IOException { + public void testToXContentEmptyParams() throws IOException { + GenerativeQAParameters parameters = new GenerativeQAParameters(); + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + parameters.toXContent(builder, null); + assertNotNull(parameters.toXContent(builder, null)); + } + + public void testToXContentAllOptionalParameters() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c";