-
Notifications
You must be signed in to change notification settings - Fork 138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate RAG pipeline to async processing. #2345
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,20 +27,23 @@ | |
import java.util.Map; | ||
import java.util.function.BooleanSupplier; | ||
|
||
import org.opensearch.OpenSearchException; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.common.Strings; | ||
import org.opensearch.ingest.ConfigurationUtils; | ||
import org.opensearch.ml.common.conversation.Interaction; | ||
import org.opensearch.ml.common.exception.MLException; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.pipeline.AbstractProcessor; | ||
import org.opensearch.search.pipeline.PipelineProcessingContext; | ||
import org.opensearch.search.pipeline.Processor; | ||
import org.opensearch.search.pipeline.SearchResponseProcessor; | ||
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; | ||
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; | ||
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; | ||
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput; | ||
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; | ||
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; | ||
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil; | ||
|
@@ -65,8 +68,6 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements | |
|
||
private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30; | ||
|
||
// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM. | ||
|
||
private final String llmModel; | ||
private final List<String> contextFields; | ||
|
||
|
@@ -106,20 +107,32 @@ protected GenerativeQAResponseProcessor( | |
} | ||
|
||
@Override | ||
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { | ||
public SearchResponse processResponse(SearchRequest searchRequest, SearchResponse searchResponse) { | ||
// Synchronous call is no longer supported because this execution can occur on a transport thread. | ||
throw new UnsupportedOperationException(); | ||
} | ||
|
||
log.info("Entering processResponse."); | ||
@Override | ||
public void processResponseAsync( | ||
SearchRequest request, | ||
SearchResponse response, | ||
PipelineProcessingContext requestContext, | ||
ActionListener<SearchResponse> responseListener | ||
) { | ||
log.debug("Entering processResponse."); | ||
|
||
if (!this.featureFlagSupplier.getAsBoolean()) { | ||
throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG); | ||
} | ||
|
||
GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request); | ||
|
||
Integer timeout = params.getTimeout(); | ||
if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) { | ||
timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS; | ||
Integer t = params.getTimeout(); | ||
if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) { | ||
t = DEFAULT_PROCESSOR_TIME_IN_SECONDS; | ||
} | ||
final int timeout = t; | ||
log.debug("Timeout for this request: {} seconds.", timeout); | ||
|
||
String llmQuestion = params.getLlmQuestion(); | ||
String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel(); | ||
|
@@ -128,14 +141,15 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp | |
} | ||
String conversationId = params.getConversationId(); | ||
|
||
if (conversationId != null && !Strings.hasText(conversationId)) { | ||
throw new IllegalArgumentException("Empty conversation_id is not allowed."); | ||
} | ||
Instant start = Instant.now(); | ||
Integer interactionSize = params.getInteractionSize(); | ||
if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) { | ||
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW; | ||
} | ||
List<Interaction> chatHistory = (conversationId == null) | ||
? Collections.emptyList() | ||
: memoryClient.getInteractions(conversationId, interactionSize); | ||
log.debug("Using interaction size of {}", interactionSize); | ||
|
||
Integer topN = params.getContextSize(); | ||
if (topN == null) { | ||
|
@@ -153,10 +167,32 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp | |
effectiveUserInstructions = params.getUserInstructions(); | ||
} | ||
|
||
start = Instant.now(); | ||
try { | ||
ChatCompletionOutput output = llm | ||
.doChatCompletion( | ||
final List<Interaction> chatHistory = new ArrayList<>(); | ||
if (conversationId == null) { | ||
doChatCompletion( | ||
LlmIOUtil | ||
.createChatCompletionInput( | ||
systemPrompt, | ||
userInstructions, | ||
llmModel, | ||
llmQuestion, | ||
chatHistory, | ||
searchResults, | ||
timeout, | ||
params.getLlmResponseField() | ||
), | ||
null, | ||
llmQuestion, | ||
searchResults, | ||
response, | ||
responseListener | ||
); | ||
} else { | ||
final Instant memoryStart = Instant.now(); | ||
memoryClient.getInteractions(conversationId, interactionSize, ActionListener.wrap(r -> { | ||
log.debug("getInteractions complete. ({})", getDuration(memoryStart)); | ||
chatHistory.addAll(r); | ||
doChatCompletion( | ||
LlmIOUtil | ||
.createChatCompletionInput( | ||
systemPrompt, | ||
|
@@ -167,53 +203,82 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp | |
searchResults, | ||
timeout, | ||
params.getLlmResponseField() | ||
) | ||
), | ||
conversationId, | ||
llmQuestion, | ||
searchResults, | ||
response, | ||
responseListener | ||
); | ||
log.info("doChatCompletion complete. ({})", getDuration(start)); | ||
}, responseListener::onFailure)); | ||
} | ||
} | ||
|
||
String answer = null; | ||
String errorMessage = null; | ||
String interactionId = null; | ||
if (output.isErrorOccurred()) { | ||
errorMessage = output.getErrors().get(0); | ||
} else { | ||
answer = (String) output.getAnswers().get(0); | ||
private void doChatCompletion( | ||
ChatCompletionInput input, | ||
String conversationId, | ||
String llmQuestion, | ||
List<String> searchResults, | ||
SearchResponse response, | ||
ActionListener<SearchResponse> responseListener | ||
) { | ||
|
||
final Instant chatStart = Instant.now(); | ||
llm.doChatCompletion(input, new ActionListener<>() { | ||
@Override | ||
public void onResponse(ChatCompletionOutput output) { | ||
log.debug("doChatCompletion complete. ({})", getDuration(chatStart)); | ||
|
||
final String answer = getAnswer(output); | ||
final String errorMessage = getError(output); | ||
|
||
if (conversationId != null) { | ||
start = Instant.now(); | ||
interactionId = memoryClient | ||
final Instant memoryStart = Instant.now(); | ||
memoryClient | ||
.createInteraction( | ||
conversationId, | ||
llmQuestion, | ||
PromptUtil.getPromptTemplate(systemPrompt, userInstructions), | ||
answer, | ||
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, | ||
Collections.singletonMap("metadata", jsonArrayToString(searchResults)) | ||
Collections.singletonMap("metadata", jsonArrayToString(searchResults)), | ||
ActionListener.wrap(r -> { | ||
responseListener.onResponse(insertAnswer(response, answer, errorMessage, r)); | ||
log.info("Created a new interaction: {} ({})", r, getDuration(memoryStart)); | ||
}, responseListener::onFailure) | ||
); | ||
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start)); | ||
|
||
} else { | ||
responseListener.onResponse(insertAnswer(response, answer, errorMessage, null)); | ||
} | ||
|
||
} | ||
|
||
return insertAnswer(response, answer, errorMessage, interactionId); | ||
} catch (NullPointerException nullPointerException) { | ||
throw new IllegalArgumentException(IllegalArgumentMessage); | ||
} catch (Exception e) { | ||
throw new OpenSearchException("GenerativeQAResponseProcessor failed in precessing response"); | ||
} | ||
} | ||
@Override | ||
public void onFailure(Exception e) { | ||
responseListener.onFailure(e); | ||
} | ||
|
||
long getDuration(Instant start) { | ||
return Duration.between(start, Instant.now()).toMillis(); | ||
private String getError(ChatCompletionOutput output) { | ||
return output.isErrorOccurred() ? output.getErrors().get(0) : null; | ||
} | ||
|
||
private String getAnswer(ChatCompletionOutput output) { | ||
return output.isErrorOccurred() ? null : (String) output.getAnswers().get(0); | ||
} | ||
}); | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE; | ||
} | ||
|
||
private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) { | ||
private long getDuration(Instant start) { | ||
return Duration.between(start, Instant.now()).toMillis(); | ||
} | ||
|
||
// TODO return the interaction id in the response. | ||
private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) { | ||
|
||
return new GenerativeSearchResponse( | ||
answer, | ||
|
@@ -240,9 +305,7 @@ private List<String> getSearchResults(SearchResponse response, Integer topN) { | |
for (String contextField : contextFields) { | ||
Object context = docSourceMap.get(contextField); | ||
if (context == null) { | ||
log.error("Context " + contextField + " not found in search hit " + hits[i]); | ||
// TODO throw a more meaningful error here? | ||
throw new RuntimeException(); | ||
throw new RuntimeException("Context " + contextField + " not found in search hit " + hits[i]); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly, you need to make sure that this exception gets propagated to the listener. (I don't remember if that's covered by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll check and also test. |
||
} | ||
searchResults.add(context.toString()); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you managed to test this?
You should probably invoke
responseListener.onFailure()
. Otherwise, the current thread may throw and the listener would sit there waiting for a response.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I have a test for this, but it does not go through the REST layer. I may need an IT test.