Skip to content

Commit

Permalink
Add support for context_size and include 'interaction_id' in SearchRe…
Browse files Browse the repository at this point in the history
…sponse. [Issue opensearch-project#1372]

Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee committed Sep 25, 2023
1 parent 2c8cc02 commit e499ec5
Show file tree
Hide file tree
Showing 22 changed files with 564 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ public class GenerativeQAProcessorConstants {
// The field in search results that contain the context to be sent to the LLM.
public static final String CONFIG_NAME_CONTEXT_FIELD_LIST = "context_field_list";

public static final String CONFIG_NAME_SYSTEM_PROMPT = "system_prompt";
public static final String CONFIG_NAME_USER_INSTRUCTIONS = "user_instructions";

public static final Setting<Boolean> RAG_PIPELINE_FEATURE_ENABLED = Setting
.boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import lombok.Getter;
import lombok.Setter;
Expand All @@ -41,10 +40,13 @@
import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BooleanSupplier;

import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;
Expand All @@ -58,11 +60,16 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements

private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;

private static final int MAX_PROCESSOR_TIME_IN_SECONDS = 60;

// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.

private final String llmModel;
private final List<String> contextFields;

private final String systemPrompt;
private final String userInstructions;

@Setter
private ConversationalMemoryClient memoryClient;

Expand All @@ -74,10 +81,12 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements
private final BooleanSupplier featureFlagSupplier;

protected GenerativeQAResponseProcessor(Client client, String tag, String description, boolean ignoreFailure,
Llm llm, String llmModel, List<String> contextFields, BooleanSupplier supplier) {
Llm llm, String llmModel, List<String> contextFields, String systemPrompt, String userInstructions, BooleanSupplier supplier) {
super(tag, description, ignoreFailure);
this.llmModel = llmModel;
this.contextFields = contextFields;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
this.llm = llm;
this.memoryClient = new ConversationalMemoryClient(client);
this.featureFlagSupplier = supplier;
Expand All @@ -93,45 +102,94 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}

GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);

Integer timeout = params.getTimeout();
if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
timeout = MAX_PROCESSOR_TIME_IN_SECONDS;
}
log.info("Timeout for this request: {} seconds.", timeout);

String llmQuestion = params.getLlmQuestion();
String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
if (llmModel == null) {
throw new IllegalArgumentException("llm_model cannot be null.");
}
String conversationId = params.getConversationId();
log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);
List<Interaction> chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW);
List<String> searchResults = getSearchResults(response);
ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, searchResults));
String answer = (String) output.getAnswers().get(0);
Instant start = Instant.now();
Integer interactionSize = params.getInteractionSize();
if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) {
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW;
}
log.info("Using interaction size of {}", interactionSize);
List<Interaction> chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, interactionSize);
log.info("Retrieved chat history. ({})", getDuration(start));

Integer topN = params.getContextSize();
if (topN == null) {
topN = GenerativeQAParameters.SIZE_NULL_VALUE;
}
List<String> searchResults = getSearchResults(response, topN);

log.info("system_prompt: {}", systemPrompt);
log.info("user_instructions: {}", userInstructions);
start = Instant.now();
ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(systemPrompt, userInstructions, llmModel,
llmQuestion, chatHistory, searchResults, timeout));
log.info("doChatCompletion complete. ({})", getDuration(start));

String answer = null;
String errorMessage = null;
String interactionId = null;
if (conversationId != null) {
interactionId = memoryClient.createInteraction(conversationId, llmQuestion, PromptUtil.DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE, answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, jsonArrayToString(searchResults));
if (output.isErrorOccurred()) {
errorMessage = output.getErrors().get(0);
} else {
answer = (String) output.getAnswers().get(0);

if (conversationId != null) {
start = Instant.now();
interactionId = memoryClient.createInteraction(conversationId,
llmQuestion,
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
jsonArrayToString(searchResults)
);
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
}
}

return insertAnswer(response, answer, interactionId);
return insertAnswer(response, answer, errorMessage, interactionId);
}

long getDuration(Instant start) {
return Duration.between(start, Instant.now()).toMillis();
}

@Override
public String getType() {
return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
}

private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) {
private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {

// TODO return the interaction id in the response.

return new GenerativeSearchResponse(answer, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(),
response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters());
return new GenerativeSearchResponse(answer, errorMessage, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(),
response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters(), interactionId);
}

private List<String> getSearchResults(SearchResponse response) {
private List<String> getSearchResults(SearchResponse response, Integer topN) {
List<String> searchResults = new ArrayList<>();
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> docSourceMap = hit.getSourceAsMap();
SearchHit[] hits = response.getHits().getHits();
int total = hits.length;
int end = (topN != GenerativeQAParameters.SIZE_NULL_VALUE) ? Math.min(topN, total) : total;
for (int i = 0; i < end; i++) {
Map<String, Object> docSourceMap = hits[i].getSourceAsMap();
for (String contextField : contextFields) {
Object context = docSourceMap.get(contextField);
if (context == null) {
log.error("Context " + contextField + " not found in search hit " + hit);
log.error("Context " + contextField + " not found in search hit " + hits[i]);
// TODO throw a more meaningful error here?
throw new RuntimeException();
}
Expand Down Expand Up @@ -189,14 +247,27 @@ public SearchResponseProcessor create(
"required property can't be empty."
);
}
log.info("model_id {}, llm_model {}, context_field_list {}", modelId, llmModel, contextFields);
String systemPrompt = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT
);
String userInstructions = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS
);
log.info("model_id {}, llm_model {}, context_field_list {}, system_prompt {}, user_instructions {}",
modelId, llmModel, contextFields, systemPrompt, userInstructions);
return new GenerativeQAResponseProcessor(client,
tag,
description,
ignoreFailure,
ModelLocator.getLlm(modelId, client),
llmModel,
contextFields,
systemPrompt,
userInstructions,
featureFlagSupplier
);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

/**
* This is an extension of SearchResponse that adds LLM-generated answers to search responses in a dedicated "ext" section.
Expand All @@ -33,22 +34,32 @@ public class GenerativeSearchResponse extends SearchResponse {

private static final String EXT_SECTION_NAME = "ext";
private static final String GENERATIVE_QA_ANSWER_FIELD_NAME = "answer";
private static final String GENERATIVE_QA_ERROR_FIELD_NAME = "error";
private static final String INTERACTION_ID_FIELD_NAME = "interaction_id";

private final String answer;
private String errorMessage;
private final String interactionId;

public GenerativeSearchResponse(
String answer,
String errorMessage,
SearchResponseSections internalResponse,
String scrollId,
int totalShards,
int successfulShards,
int skippedShards,
long tookInMillis,
ShardSearchFailure[] shardFailures,
Clusters clusters
Clusters clusters,
String interactionId
) {
super(internalResponse, scrollId, totalShards, successfulShards, skippedShards, tookInMillis, shardFailures, clusters);
this.answer = answer;
if (answer == null) {
this.errorMessage = Objects.requireNonNull(errorMessage, "If answer is not given, errorMessage must be provided.");
}
this.interactionId = interactionId;
}

@Override
Expand All @@ -57,7 +68,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
innerToXContent(builder, params);
/* start of ext */ builder.startObject(EXT_SECTION_NAME);
/* start of our stuff */ builder.startObject(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE);
/* body of our stuff */ builder.field(GENERATIVE_QA_ANSWER_FIELD_NAME, this.answer);
if (answer == null) {
builder.field(GENERATIVE_QA_ERROR_FIELD_NAME, this.errorMessage);
} else {
/* body of our stuff */
builder.field(GENERATIVE_QA_ANSWER_FIELD_NAME, this.answer);
if (this.interactionId != null) {
/* interaction id */
builder.field(INTERACTION_ID_FIELD_NAME, this.interactionId);
}
}
/* end of our stuff */ builder.endObject();
/* end of ext */ builder.endObject();
builder.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@
import lombok.extern.log4j.Log4j2;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.client.Client;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.index.reindex.ScrollableHitSource;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
Expand All @@ -37,6 +42,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

/**
* An OpenSearch client wrapper for conversational memory related calls.
Expand All @@ -46,12 +52,13 @@
public class ConversationalMemoryClient {

private final static Logger logger = LogManager.getLogger();
private final static long DEFAULT_TIMEOUT_IN_MILLIS = 10_000l;

private Client client;

public String createConversation(String name) {

CreateConversationResponse response = client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name)).actionGet();
CreateConversationResponse response = client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name)).actionGet(DEFAULT_TIMEOUT_IN_MILLIS);
log.info("createConversation: id: {}", response.getId());
return response.getId();
}
Expand All @@ -61,7 +68,7 @@ public String createInteraction(String conversationId, String input, String prom
Preconditions.checkNotNull(input);
Preconditions.checkNotNull(response);
CreateInteractionResponse res = client.execute(CreateInteractionAction.INSTANCE,
new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo)).actionGet();
new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo)).actionGet(DEFAULT_TIMEOUT_IN_MILLIS);
log.info("createInteraction: interactionId: {}", res.getId());
return res.getId();
}
Expand All @@ -78,7 +85,7 @@ public List<Interaction> getInteractions(String conversationId, int lastN) {
int maxResults = lastN;
do {
GetInteractionsResponse response =
client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, maxResults, from)).actionGet();
client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, maxResults, from)).actionGet(DEFAULT_TIMEOUT_IN_MILLIS);
List<Interaction> list = response.getInteractions();
if (list != null && !CollectionUtils.isEmpty(list)) {
interactions.addAll(list);
Expand All @@ -97,6 +104,4 @@ public List<Interaction> getInteractions(String conversationId, int lastN) {

return interactions;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.client.Client;
import org.opensearch.common.action.ActionFuture;
Expand All @@ -27,6 +28,7 @@
*/
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@RequiredArgsConstructor
@Log4j2
public class MachineLearningInternalClient {

Client client;
Expand Down
Loading

0 comments on commit e499ec5

Please sign in to comment.