Skip to content

Commit

Permalink
[Backport 2.17] Bug Fix: Fix for RAG processor throwing NPE when optional parameters are not provided (#3066)
Browse files Browse the repository at this point in the history
  • Loading branch information
pyek-bot authored Oct 8, 2024
1 parent 217afd0 commit 8b5b38e
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative.ext;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.Objects;

import org.opensearch.core.ParseField;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ObjectParser;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;

import com.google.common.base.Preconditions;

import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
Expand All @@ -45,57 +45,42 @@
@NoArgsConstructor
public class GenerativeQAParameters implements Writeable, ToXContentObject {

private static final ObjectParser<GenerativeQAParameters, Void> PARSER;

// Optional parameter; if provided, conversational memory will be used for RAG
// and the current interaction will be saved in the conversation referenced by this id.
private static final ParseField CONVERSATION_ID = new ParseField("memory_id");
private static final String CONVERSATION_ID = "memory_id";

// Optional parameter; if an LLM model is not set at the search pipeline level, one must be
// provided at the search request level.
private static final ParseField LLM_MODEL = new ParseField("llm_model");
private static final String LLM_MODEL = "llm_model";

// Required parameter; this is sent to LLMs as part of the user prompt.
// TODO support question rewriting when chat history is not used (conversation_id is not provided).
private static final ParseField LLM_QUESTION = new ParseField("llm_question");
private static final String LLM_QUESTION = "llm_question";

// Optional parameter; this parameter controls the number of search results ("contexts") to
// include in the user prompt.
private static final ParseField CONTEXT_SIZE = new ParseField("context_size");
private static final String CONTEXT_SIZE = "context_size";

// Optional parameter; this parameter controls the number of the interactions to include
// in the user prompt.
private static final ParseField INTERACTION_SIZE = new ParseField("message_size");
private static final String INTERACTION_SIZE = "message_size";

// Optional parameter; this parameter controls how long the search pipeline waits for a response
// from a remote inference endpoint before timing out the request.
private static final ParseField TIMEOUT = new ParseField("timeout");
private static final String TIMEOUT = "timeout";

// Optional parameter: this parameter allows request-level customization of the "system" (role) prompt.
private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT);
private static final String SYSTEM_PROMPT = "system_prompt";

// Optional parameter: this parameter allows request-level customization of the "user" (role) prompt.
private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS);
private static final String USER_INSTRUCTIONS = "user_instructions";

// Optional parameter; this parameter indicates the name of the field in the LLM response
// that contains the chat completion text, i.e. "answer".
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");
private static final String LLM_RESPONSE_FIELD = "llm_response_field";

public static final int SIZE_NULL_VALUE = -1;

static {
PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new);
PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID);
PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL);
PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION);
PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT);
PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS);
PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
}

@Setter
@Getter
private String conversationId;
Expand Down Expand Up @@ -132,6 +117,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
@Getter
private String llmResponseField;

@Builder
public GenerativeQAParameters(
String conversationId,
String llmModel,
Expand All @@ -148,7 +134,7 @@ public GenerativeQAParameters(

// TODO: keep this requirement until we can extract the question from the query or from the request processor parameters
// for question rewriting.
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided.");
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided.");
this.llmQuestion = llmQuestion;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
Expand All @@ -172,16 +158,45 @@ public GenerativeQAParameters(StreamInput input) throws IOException {

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
    // Emit only the parameters that were actually provided. The previous implementation
    // wrote every field unconditionally, which threw an NPE when optional parameters
    // (e.g. context_size, timeout) were left null — the bug this change fixes.
    xContentBuilder.startObject();
    if (this.conversationId != null) {
        xContentBuilder.field(CONVERSATION_ID, this.conversationId);
    }

    if (this.llmModel != null) {
        xContentBuilder.field(LLM_MODEL, this.llmModel);
    }

    if (this.llmQuestion != null) {
        xContentBuilder.field(LLM_QUESTION, this.llmQuestion);
    }

    if (this.systemPrompt != null) {
        xContentBuilder.field(SYSTEM_PROMPT, this.systemPrompt);
    }

    if (this.userInstructions != null) {
        xContentBuilder.field(USER_INSTRUCTIONS, this.userInstructions);
    }

    if (this.contextSize != null) {
        xContentBuilder.field(CONTEXT_SIZE, this.contextSize);
    }

    if (this.interactionSize != null) {
        xContentBuilder.field(INTERACTION_SIZE, this.interactionSize);
    }

    if (this.timeout != null) {
        xContentBuilder.field(TIMEOUT, this.timeout);
    }

    if (this.llmResponseField != null) {
        xContentBuilder.field(LLM_RESPONSE_FIELD, this.llmResponseField);
    }

    xContentBuilder.endObject();
    return xContentBuilder;
}

@Override
Expand All @@ -200,7 +215,67 @@ public void writeTo(StreamOutput out) throws IOException {
}

/**
 * Parses a {@code generative_qa_parameters} object from the given parser.
 * All fields are treated as optional here; the required {@code llm_question}
 * is enforced by the constructor invoked via the builder.
 *
 * @param parser positioned on the START_OBJECT token of the parameters object
 * @return the parsed parameters
 * @throws IOException if the underlying content is malformed
 */
public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
    String conversationId = null;
    String llmModel = null;
    String llmQuestion = null;
    String systemPrompt = null;
    String userInstructions = null;
    Integer contextSize = null;
    Integer interactionSize = null;
    Integer timeout = null;
    String llmResponseField = null;

    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
    while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
        String field = parser.currentName();
        parser.nextToken(); // advance from the field name onto its value

        switch (field) {
            case CONVERSATION_ID:
                conversationId = parser.text();
                break;
            case LLM_MODEL:
                llmModel = parser.text();
                break;
            case LLM_QUESTION:
                llmQuestion = parser.text();
                break;
            case SYSTEM_PROMPT:
                systemPrompt = parser.text();
                break;
            case USER_INSTRUCTIONS:
                userInstructions = parser.text();
                break;
            case CONTEXT_SIZE:
                contextSize = parser.intValue();
                break;
            case INTERACTION_SIZE:
                interactionSize = parser.intValue();
                break;
            case TIMEOUT:
                timeout = parser.intValue();
                break;
            case LLM_RESPONSE_FIELD:
                llmResponseField = parser.text();
                break;
            default:
                // Unknown field: skip its value (and any nested structure) rather than fail.
                parser.skipChildren();
                break;
        }
    }

    return GenerativeQAParameters
        .builder()
        .conversationId(conversationId)
        .llmModel(llmModel)
        .llmQuestion(llmQuestion)
        .systemPrompt(systemPrompt)
        .userInstructions(userInstructions)
        .contextSize(contextSize)
        .interactionSize(interactionSize)
        .timeout(timeout)
        .llmResponseField(llmResponseField)
        .build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,23 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

import java.io.EOFException;
import java.io.IOException;
import java.util.Collections;

import org.junit.Assert;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentHelper;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchModule;
import org.opensearch.test.OpenSearchTestCase;

public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase {
Expand Down Expand Up @@ -107,21 +112,38 @@ public void testMiscMethods() throws IOException {
}

public void testParse() throws IOException {
    // Parse a request carrying only the required parameter (llm_question);
    // every optional parameter is intentionally absent to cover the NPE regression.
    String requiredJsonStr = "{\"llm_question\":\"this is test llm question\"}";

    XContentParser parser = XContentType.JSON
        .xContent()
        .createParser(
            new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
            null,
            requiredJsonStr
        );

    parser.nextToken(); // position the parser on START_OBJECT before delegating
    GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(parser);
    assertNotNull(builder);
    assertNotNull(builder.getParams());
    GenerativeQAParameters params = builder.getParams();
    Assert.assertEquals("this is test llm question", params.getLlmQuestion());
}

public void testXContentRoundTrip() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(param1);

XContentType xContentType = randomFrom(XContentType.values());
BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true);
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent());
builder = extBuilder.toXContent(builder, EMPTY_PARAMS);
BytesReference serialized = BytesReference.bytes(builder);

XContentParser parser = createParser(xContentType.xContent(), serialized);
parser.nextToken();
GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser);

assertEquals(extBuilder, deserialized);
GenerativeQAParameters parameters = deserialized.getParams();
assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getContextSize());
Expand All @@ -133,10 +155,16 @@ public void testXContentRoundTripAllValues() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3, null);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(param1);

XContentType xContentType = randomFrom(XContentType.values());
BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true);
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent());
builder = extBuilder.toXContent(builder, EMPTY_PARAMS);
BytesReference serialized = BytesReference.bytes(builder);

XContentParser parser = createParser(xContentType.xContent(), serialized);
parser.nextToken();
GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser);

assertEquals(extBuilder, deserialized);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,18 @@ public void testToXConent() throws IOException {
assertNotNull(parameters.toXContent(builder, null));
}

public void testToXConentAllOptionalParameters() throws IOException {
public void testToXContentEmptyParams() throws IOException {
GenerativeQAParameters parameters = new GenerativeQAParameters();
XContent xc = mock(XContent.class);
OutputStream os = mock(OutputStream.class);
XContentGenerator generator = mock(XContentGenerator.class);
when(xc.createGenerator(any(), any(), any())).thenReturn(generator);
XContentBuilder builder = new XContentBuilder(xc, os);
parameters.toXContent(builder, null);
assertNotNull(parameters.toXContent(builder, null));
}

public void testToXContentAllOptionalParameters() throws IOException {
String conversationId = "a";
String llmModel = "b";
String llmQuestion = "c";
Expand Down

0 comments on commit 8b5b38e

Please sign in to comment.