Skip to content

Commit

Permalink
Add request level parameters for system_prompt and user_instructions.
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee committed Mar 19, 2024
1 parent 833c9c1 commit b3b2f2d
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
+ " \"generative_qa_parameters\": {\n"
+ " \"llm_model\": \"%s\",\n"
+ " \"llm_question\": \"%s\",\n"
+ " \"system_prompt\": \"%s\",\n"
+ " \"user_instructions\": \"%s\",\n"
+ " \"context_size\": %d,\n"
+ " \"message_size\": %d,\n"
+ " \"timeout\": %d\n"
Expand All @@ -188,6 +190,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
+ " \"llm_model\": \"%s\",\n"
+ " \"llm_question\": \"%s\",\n"
+ " \"memory_id\": \"%s\",\n"
+ " \"system_prompt\": \"%s\",\n"
+ " \"user_instructions\": \"%s\",\n"
+ " \"context_size\": %d,\n"
+ " \"message_size\": %d,\n"
+ " \"timeout\": %d\n"
Expand Down Expand Up @@ -308,6 +312,8 @@ public void testBM25WithOpenAI() throws Exception {
requestParameters.match = "president";
requestParameters.llmModel = OPENAI_MODEL;
requestParameters.llmQuestion = "who is lincoln";
requestParameters.systemPrompt = "You are great at answering questions";
requestParameters.userInstructions = "Follow my instructions as best you can";
requestParameters.contextSize = 5;
requestParameters.interactionSize = 5;
requestParameters.timeout = 60;
Expand Down Expand Up @@ -527,6 +533,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.match,
requestParameters.llmModel,
requestParameters.llmQuestion,
requestParameters.systemPrompt,
requestParameters.userInstructions,
requestParameters.contextSize,
requestParameters.interactionSize,
requestParameters.timeout
Expand All @@ -541,6 +549,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.llmModel,
requestParameters.llmQuestion,
requestParameters.conversationId,
requestParameters.systemPrompt,
requestParameters.userInstructions,
requestParameters.contextSize,
requestParameters.interactionSize,
requestParameters.timeout
Expand Down Expand Up @@ -581,6 +591,8 @@ static class SearchRequestParameters {
String match;
String llmModel;
String llmQuestion;
String systemPrompt;
String userInstructions;
int contextSize;
int interactionSize;
int timeout;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,18 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}
List<String> searchResults = getSearchResults(response, topN);

// See if the prompt is being overridden at the request level.
String effectiveSystemPrompt = systemPrompt;
String effectiveUserInstructions = userInstructions;
if (params.getSystemPrompt() != null) {
effectiveSystemPrompt = params.getSystemPrompt();
}
if (params.getUserInstructions() != null) {
effectiveUserInstructions = params.getUserInstructions();
}
log.info("system_prompt: {}", effectiveSystemPrompt);
log.info("user_instructions: {}", effectiveUserInstructions);

start = Instant.now();
try {
ChatCompletionOutput output = llm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;

import com.google.common.base.Preconditions;

Expand Down Expand Up @@ -70,13 +71,19 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
// from a remote inference endpoint before timing out the request.
private static final ParseField TIMEOUT = new ParseField("timeout");

private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT);

private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS);

public static final int SIZE_NULL_VALUE = -1;

static {
PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new);
PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID);
PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL);
PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION);
PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT);
PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS);
PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
Expand Down Expand Up @@ -106,10 +113,20 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
@Getter
private Integer timeout;

@Setter
@Getter
private String systemPrompt;

@Setter
@Getter
private String userInstructions;

public GenerativeQAParameters(
String conversationId,
String llmModel,
String llmQuestion,
String systemPrompt,
String userInstructions,
Integer contextSize,
Integer interactionSize,
Integer timeout
Expand All @@ -121,6 +138,8 @@ public GenerativeQAParameters(
// for question rewriting.
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided.");
this.llmQuestion = llmQuestion;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
this.contextSize = (contextSize == null) ? SIZE_NULL_VALUE : contextSize;
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
Expand All @@ -130,6 +149,8 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
this.conversationId = input.readOptionalString();
this.llmModel = input.readOptionalString();
this.llmQuestion = input.readString();
this.systemPrompt = input.readOptionalString();
this.userInstructions = input.readOptionalString();
this.contextSize = input.readInt();
this.interactionSize = input.readInt();
this.timeout = input.readInt();
Expand All @@ -141,6 +162,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
.field(CONVERSATION_ID.getPreferredName(), this.conversationId)
.field(LLM_MODEL.getPreferredName(), this.llmModel)
.field(LLM_QUESTION.getPreferredName(), this.llmQuestion)
.field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt)
.field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions)
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
.field(TIMEOUT.getPreferredName(), this.timeout);
Expand All @@ -153,6 +176,8 @@ public void writeTo(StreamOutput out) throws IOException {

Preconditions.checkNotNull(llmQuestion, "llm_question must not be null.");
out.writeString(llmQuestion);
out.writeOptionalString(systemPrompt);
out.writeOptionalString(userInstructions);
out.writeInt(contextSize);
out.writeInt(interactionSize);
out.writeInt(timeout);
Expand All @@ -175,6 +200,8 @@ public boolean equals(Object o) {
return Objects.equals(this.conversationId, other.getConversationId())
&& Objects.equals(this.llmModel, other.getLlmModel())
&& Objects.equals(this.llmQuestion, other.getLlmQuestion())
&& Objects.equals(this.systemPrompt, other.getSystemPrompt())
&& Objects.equals(this.userInstructions, other.getUserInstructions())
&& (this.contextSize == other.getContextSize())
&& (this.interactionSize == other.getInteractionSize())
&& (this.timeout == other.getTimeout());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,16 @@ public void testProcessResponseNoSearchHits() throws Exception {

SearchRequest request = new SearchRequest(); // mock(SearchRequest.class);
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // mock(SearchSourceBuilder.class);
GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null);
GenerativeQAParameters params = new GenerativeQAParameters(
"12345",
"llm_model",
"You are kind.",
"system_prompt",
"user_instructions",
null,
null,
null
);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(params);
request.source(sourceBuilder);
Expand Down Expand Up @@ -170,7 +179,16 @@ public void testProcessResponse() throws Exception {

SearchRequest request = new SearchRequest();
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null);
GenerativeQAParameters params = new GenerativeQAParameters(
"12345",
"llm_model",
"You are kind.",
"system_promt",
"user_insturctions",
null,
null,
null
);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(params);
request.source(sourceBuilder);
Expand Down Expand Up @@ -245,7 +263,16 @@ public void testProcessResponseSmallerContextSize() throws Exception {
SearchRequest request = new SearchRequest();
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
int contextSize = 5;
GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", contextSize, null, null);
GenerativeQAParameters params = new GenerativeQAParameters(
"12345",
"llm_model",
"You are kind.",
"system_prompt",
"user_instructions",
contextSize,
null,
null
);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(params);
request.source(sourceBuilder);
Expand Down Expand Up @@ -319,7 +346,16 @@ public void testProcessResponseMissingContextField() throws Exception {

SearchRequest request = new SearchRequest();
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null);
GenerativeQAParameters params = new GenerativeQAParameters(
"12345",
"llm_model",
"You are kind.",
"system_prompt",
"user_instructions",
null,
null,
null
);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(params);
request.source(sourceBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,16 @@ public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase {

public void testCtor() throws IOException {
GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder();
GenerativeQAParameters parameters = new GenerativeQAParameters("conversation_id", "model_id", "question", null, null, null);
GenerativeQAParameters parameters = new GenerativeQAParameters(
"conversation_id",
"model_id",
"question",
"system_promtp",
"user_instructions",
null,
null,
null
);
builder.setParams(parameters);
assertEquals(parameters, builder.getParams());

Expand Down Expand Up @@ -79,8 +88,8 @@ public int read() throws IOException {
}

public void testMiscMethods() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null);
GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d", null, null, null);
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null);
GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d", "s", "u", null, null, null);
GenerativeQAParamExtBuilder builder1 = new GenerativeQAParamExtBuilder();
GenerativeQAParamExtBuilder builder2 = new GenerativeQAParamExtBuilder();
builder1.setParams(param1);
Expand All @@ -92,7 +101,7 @@ public void testMiscMethods() throws IOException {

StreamOutput so = mock(StreamOutput.class);
builder1.writeTo(so);
verify(so, times(2)).writeOptionalString(any());
verify(so, times(4)).writeOptionalString(any());
verify(so, times(1)).writeString(any());
}

Expand All @@ -105,7 +114,7 @@ public void testParse() throws IOException {
}

public void testXContentRoundTrip() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null);
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(param1);
XContentType xContentType = randomFrom(XContentType.values());
Expand All @@ -120,7 +129,7 @@ public void testXContentRoundTrip() throws IOException {
}

public void testXContentRoundTripAllValues() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3);
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(param1);
XContentType xContentType = randomFrom(XContentType.values());
Expand All @@ -131,7 +140,7 @@ public void testXContentRoundTripAllValues() throws IOException {
}

public void testStreamRoundTrip() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null);
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(param1);
BytesStreamOutput bso = new BytesStreamOutput();
Expand All @@ -145,7 +154,7 @@ public void testStreamRoundTrip() throws IOException {
}

public void testStreamRoundTripAllValues() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3);
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(param1);
BytesStreamOutput bso = new BytesStreamOutput();
Expand Down
Loading

0 comments on commit b3b2f2d

Please sign in to comment.