
Migrate RAG pipeline to async processing. #2345

Merged · 2 commits · Apr 23, 2024
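This PR moves the generative QA response processor off the blocking SearchResponseProcessor entry point and onto the listener-based async one, so no search-pipeline work blocks a transport thread. A minimal sketch of the migration pattern (illustrative only, not code from this PR; in the diff excerpts below, removed lines are prefixed with `-`):

```java
import org.opensearch.core.action.ActionListener;

public class AsyncMigrationSketch {
    // Before: the caller blocks until a result is returned.
    static String processSync(String input) {
        return "processed:" + input;
    }

    // After: the result (or failure) is delivered through a listener and the
    // calling thread returns immediately, which is safe on a transport thread.
    static void processAsync(String input, ActionListener<String> listener) {
        try {
            listener.onResponse("processed:" + input);
        } catch (Exception e) {
            listener.onFailure(e); // never let the exception escape the async path
        }
    }

    public static void main(String[] args) {
        processAsync("query", ActionListener.wrap(System.out::println, Throwable::printStackTrace));
    }
}
```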
File: GenerativeQAResponseProcessor.java

@@ -27,20 +27,23 @@
import java.util.Map;
import java.util.function.BooleanSupplier;

- import org.opensearch.OpenSearchException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.search.SearchHit;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
@@ -65,8 +68,6 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements

private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;

- // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.

private final String llmModel;
private final List<String> contextFields;

@@ -106,20 +107,32 @@ protected GenerativeQAResponseProcessor(
}

@Override
- public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
public SearchResponse processResponse(SearchRequest searchRequest, SearchResponse searchResponse) {
// Synchronous call is no longer supported because this execution can occur on a transport thread.
throw new UnsupportedOperationException();
}

log.info("Entering processResponse.");
@Override
public void processResponseAsync(
SearchRequest request,
SearchResponse response,
PipelineProcessingContext requestContext,
ActionListener<SearchResponse> responseListener
) {
log.debug("Entering processResponse.");

if (!this.featureFlagSupplier.getAsBoolean()) {
throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG);
}

GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);

- Integer timeout = params.getTimeout();
- if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
- timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
Integer t = params.getTimeout();
if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) {
t = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
}
final int timeout = t;
log.debug("Timeout for this request: {} seconds.", timeout);

String llmQuestion = params.getLlmQuestion();
String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
@@ -128,14 +141,15 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}
String conversationId = params.getConversationId();

if (conversationId != null && !Strings.hasText(conversationId)) {
throw new IllegalArgumentException("Empty conversation_id is not allowed.");
[Review comment] Have you managed to test this?

You should probably invoke responseListener.onFailure(). Otherwise, the current thread may throw and the listener would sit there waiting for a response.

[Author reply] Yes, I have a test for this, but it does not go through the REST layer. I may need an IT test.
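A minimal sketch of the guard being suggested (illustrative only; ListenerValidation.validateOrFail is a hypothetical helper, not part of this PR):

```java
import org.opensearch.core.action.ActionListener;

// Hypothetical helper: report validation failures through the listener instead of
// throwing on the transport thread, so the caller is notified rather than left waiting.
final class ListenerValidation {
    static <T> boolean validateOrFail(boolean valid, String message, ActionListener<T> listener) {
        if (!valid) {
            listener.onFailure(new IllegalArgumentException(message));
        }
        return valid;
    }
}
```

Inside processResponseAsync, the empty-id check would then become `if (!ListenerValidation.validateOrFail(conversationId == null || Strings.hasText(conversationId), "Empty conversation_id is not allowed.", responseListener)) return;`.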

}
- Instant start = Instant.now();
Integer interactionSize = params.getInteractionSize();
if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) {
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW;
}
- List<Interaction> chatHistory = (conversationId == null)
- ? Collections.emptyList()
- : memoryClient.getInteractions(conversationId, interactionSize);
log.debug("Using interaction size of {}", interactionSize);

Integer topN = params.getContextSize();
if (topN == null) {
@@ -153,10 +167,32 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
effectiveUserInstructions = params.getUserInstructions();
}

- start = Instant.now();
- try {
- ChatCompletionOutput output = llm
- .doChatCompletion(
final List<Interaction> chatHistory = new ArrayList<>();
if (conversationId == null) {
doChatCompletion(
LlmIOUtil
.createChatCompletionInput(
systemPrompt,
userInstructions,
llmModel,
llmQuestion,
chatHistory,
searchResults,
timeout,
params.getLlmResponseField()
),
null,
llmQuestion,
searchResults,
response,
responseListener
);
} else {
final Instant memoryStart = Instant.now();
memoryClient.getInteractions(conversationId, interactionSize, ActionListener.wrap(r -> {
log.debug("getInteractions complete. ({})", getDuration(memoryStart));
chatHistory.addAll(r);
doChatCompletion(
LlmIOUtil
.createChatCompletionInput(
systemPrompt,
@@ -167,53 +203,82 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
searchResults,
timeout,
params.getLlmResponseField()
- )
),
conversationId,
llmQuestion,
searchResults,
response,
responseListener
);
log.info("doChatCompletion complete. ({})", getDuration(start));
}, responseListener::onFailure));
}
}

- String answer = null;
- String errorMessage = null;
- String interactionId = null;
- if (output.isErrorOccurred()) {
- errorMessage = output.getErrors().get(0);
- } else {
- answer = (String) output.getAnswers().get(0);
private void doChatCompletion(
ChatCompletionInput input,
String conversationId,
String llmQuestion,
List<String> searchResults,
SearchResponse response,
ActionListener<SearchResponse> responseListener
) {

final Instant chatStart = Instant.now();
llm.doChatCompletion(input, new ActionListener<>() {
@Override
public void onResponse(ChatCompletionOutput output) {
log.debug("doChatCompletion complete. ({})", getDuration(chatStart));

final String answer = getAnswer(output);
final String errorMessage = getError(output);

if (conversationId != null) {
- start = Instant.now();
- interactionId = memoryClient
final Instant memoryStart = Instant.now();
memoryClient
.createInteraction(
conversationId,
llmQuestion,
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
- Collections.singletonMap("metadata", jsonArrayToString(searchResults))
Collections.singletonMap("metadata", jsonArrayToString(searchResults)),
ActionListener.wrap(r -> {
responseListener.onResponse(insertAnswer(response, answer, errorMessage, r));
log.info("Created a new interaction: {} ({})", r, getDuration(memoryStart));
}, responseListener::onFailure)
);
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));

} else {
responseListener.onResponse(insertAnswer(response, answer, errorMessage, null));
}

}

- return insertAnswer(response, answer, errorMessage, interactionId);
- } catch (NullPointerException nullPointerException) {
- throw new IllegalArgumentException(IllegalArgumentMessage);
- } catch (Exception e) {
- throw new OpenSearchException("GenerativeQAResponseProcessor failed in precessing response");
}
}
@Override
public void onFailure(Exception e) {
responseListener.onFailure(e);
}

- long getDuration(Instant start) {
- return Duration.between(start, Instant.now()).toMillis();
private String getError(ChatCompletionOutput output) {
return output.isErrorOccurred() ? output.getErrors().get(0) : null;
}

private String getAnswer(ChatCompletionOutput output) {
return output.isErrorOccurred() ? null : (String) output.getAnswers().get(0);
}
});
}

@Override
public String getType() {
return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
}

- private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {
private long getDuration(Instant start) {
return Duration.between(start, Instant.now()).toMillis();
}

// TODO return the interaction id in the response.
private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {

return new GenerativeSearchResponse(
answer,
Expand All @@ -240,9 +305,7 @@ private List<String> getSearchResults(SearchResponse response, Integer topN) {
for (String contextField : contextFields) {
Object context = docSourceMap.get(contextField);
if (context == null) {
- log.error("Context " + contextField + " not found in search hit " + hits[i]);
- // TODO throw a more meaningful error here?
- throw new RuntimeException();
throw new RuntimeException("Context " + contextField + " not found in search hit " + hits[i]);
[Review comment] Similarly, you need to make sure that this exception gets propagated to the listener. (I don't remember if that's covered by ActionListener.wrap(). Maybe?)

[Author reply] I'll check and also test.
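For reference, the wrap contract can be checked directly. In recent OpenSearch versions the listener returned by wrap() catches exceptions thrown by the response consumer and forwards them to the failure handler, so a throw inside the wrapped callback does reach the listener, while a throw on the calling thread before any callback runs does not (worth verifying against the target version). A minimal sketch:

```java
import org.opensearch.core.action.ActionListener;

public class WrapPropagationCheck {
    public static void main(String[] args) {
        ActionListener<String> wrapped = ActionListener.wrap(
            r -> { throw new RuntimeException("thrown inside the callback"); },
            e -> System.out.println("propagated to onFailure: " + e.getMessage())
        );
        // Prints "propagated to onFailure: thrown inside the callback".
        wrapped.onResponse("any response");
    }
}
```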

}
searchResults.add(context.toString());
}
File: ConversationalMemoryClient.java

@@ -24,6 +24,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
@@ -83,6 +84,33 @@ public String createInteraction(
return res.getId();
}

public void createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
Map<String, String> additionalInfo,
ActionListener<String> listener
) {
client
.execute(
CreateInteractionAction.INSTANCE,
new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo),
new ActionListener<CreateInteractionResponse>() {
@Override
public void onResponse(CreateInteractionResponse createInteractionResponse) {
listener.onResponse(createInteractionResponse.getId());
}

@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
}
);
}

public List<Interaction> getInteractions(String conversationId, int lastN) {

Preconditions.checkArgument(lastN > 0, "lastN must be at least 1.");
@@ -113,4 +141,23 @@ public List<Interaction> getInteractions(String conversationId, int lastN) {

return interactions;
}

public void getInteractions(String conversationId, int lastN, ActionListener<List<Interaction>> listener) {
client
.execute(
GetInteractionsAction.INSTANCE,
new GetInteractionsRequest(conversationId, lastN, 0),
new ActionListener<GetInteractionsResponse>() {
@Override
public void onResponse(GetInteractionsResponse getInteractionsResponse) {
listener.onResponse(getInteractionsResponse.getInteractions());
}

@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
}
);
}
}
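An aside on the two overloads added above: both wrap the delegate listener by hand. Since ActionListener.wrap is already used elsewhere in this PR, each anonymous listener could collapse to a one-liner; a sketch of the equivalent createInteraction (not part of the diff):

```java
public void createInteraction(
    String conversationId,
    String input,
    String promptTemplate,
    String response,
    String origin,
    Map<String, String> additionalInfo,
    ActionListener<String> listener
) {
    // wrap() maps the response to its id and forwards failures to the delegate automatically.
    client.execute(
        CreateInteractionAction.INSTANCE,
        new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo),
        ActionListener.wrap(r -> listener.onResponse(r.getId()), listener::onFailure)
    );
}
```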
File: MachineLearningInternalClient.java

@@ -42,7 +42,7 @@ public ActionFuture<MLOutput> predict(String modelId, MLInput mlInput) {
}

@VisibleForTesting
- void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
validateMLInput(mlInput, true);

MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest
File: DefaultLlmImpl.java

@@ -25,7 +25,7 @@
import java.util.Map;

import org.opensearch.client.Client;
- import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
@@ -75,20 +75,35 @@ protected void setMlClient(MachineLearningInternalClient mlClient) {
* @return
*/
@Override
- public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {

public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) {
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
ActionFuture<MLOutput> future = mlClient.predict(this.openSearchModelId, mlInput);
ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000);

// Response from a remote model
Map<String, ?> dataAsMap = modelOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
// log.info("dataAsMap: {}", dataAsMap.toString());

// TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases.
mlClient.predict(this.openSearchModelId, mlInput, new ActionListener<>() {
@Override
public void onResponse(MLOutput mlOutput) {
// Response from a remote model
Map<String, ?> dataAsMap = ((ModelTensorOutput) mlOutput)
.getMlModelOutputs()
.get(0)
.getMlModelTensors()
.get(0)
.getDataAsMap();
listener
.onResponse(
buildChatCompletionOutput(
chatCompletionInput.getModelProvider(),
dataAsMap,
chatCompletionInput.getLlmResponseField()
)
);
}

- return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap, chatCompletionInput.getLlmResponseField());
@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
});
}
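One note carried over from the deleted comment above: the old synchronous path had a TODO about dataAsMap being null or carrying throttling information, and that concern still applies on the async path. A hedged sketch of a guard at the top of onResponse (illustrative, not part of this PR; the error shape and message are assumptions):

```java
// Sketch: route a missing payload to the listener instead of risking an NPE
// inside the callback.
if (dataAsMap == null) {
    listener.onFailure(new IllegalStateException("Remote model returned no data (throttled?)"));
    return;
}
```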

protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
File: Llm.java

@@ -17,6 +17,8 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative.llm;

import org.opensearch.core.action.ActionListener;

/**
* Capabilities of large language models, e.g. completion, embeddings, etc.
*/
@@ -29,5 +31,5 @@ enum ModelProvider {
COHERE
}

- ChatCompletionOutput doChatCompletion(ChatCompletionInput input);
void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener);
}
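With the listener-based signature, a test double no longer has to fabricate a ChatCompletionOutput on the calling thread. A minimal stub (a sketch; assumes doChatCompletion is the interface's only abstract method, so a lambda satisfies it, and that the stub lives in the same package as Llm):

```java
import org.opensearch.core.action.ActionListener;

public class LlmStubExample {
    // Failing stub for tests: exercises the onFailure path without constructing
    // a ChatCompletionOutput (whose shape is not shown in this diff).
    static final Llm FAILING_STUB = (input, listener) ->
        listener.onFailure(new UnsupportedOperationException("stub: no LLM configured"));
}
```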