Skip to content

Commit

Permalink
support regenerate for chatbot (#1816)
Browse files Browse the repository at this point in the history
* support regenerate for chatbot

Signed-off-by: Hailong Cui <[email protected]>

* update test method name

Signed-off-by: Hailong Cui <[email protected]>

* exclude MLModelCacheHelper for jacoco coverage check

Signed-off-by: Hailong Cui <[email protected]>

* Address review comments

Signed-off-by: Hailong Cui <[email protected]>

* Address review comments

Signed-off-by: Hailong Cui <[email protected]>

---------

Signed-off-by: Hailong Cui <[email protected]>
  • Loading branch information
Hailong-am authored Dec 29, 2023
1 parent 42ea8cd commit f30c3cb
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest;

import com.google.common.annotations.VisibleForTesting;
import com.google.gson.Gson;
Expand All @@ -62,6 +64,7 @@ public class MLAgentExecutor implements Executable {
public static final String MEMORY_ID = "memory_id";
public static final String QUESTION = "question";
public static final String PARENT_INTERACTION_ID = "parent_interaction_id";
public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id";

private Client client;
private Settings settings;
Expand Down Expand Up @@ -113,9 +116,14 @@ public void execute(Input input, ActionListener<Output> listener) {
MLMemorySpec memorySpec = mlAgent.getMemory();
String memoryId = inputDataSet.getParameters().get(MEMORY_ID);
String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID);
String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID);
String appType = mlAgent.getAppType();
String question = inputDataSet.getParameters().get(QUESTION);

if (memoryId == null && regenerateInteractionId != null) {
throw new IllegalArgumentException("A memory ID must be provided to regenerate.");
}

if (memorySpec != null
&& memorySpec.getType() != null
&& memoryFactoryMap.containsKey(memorySpec.getType())
Expand All @@ -124,28 +132,27 @@ public void execute(Input input, ActionListener<Output> listener) {
(ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType());
conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> {
inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId());
// Create root interaction ID
ConversationIndexMessage msg = ConversationIndexMessage
.conversationIndexMessageBuilder()
.type(appType)
.question(question)
.response("")
.finalAnswer(true)
.sessionId(memory.getConversationId())
.build();
memory.save(msg, null, null, null, ActionListener.<CreateInteractionResponse>wrap(interaction -> {
log.info("Created parent interaction ID: " + interaction.getId());
inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId());
ActionListener<Object> agentActionListener = createAgentActionListener(
listener,
outputs,
modelTensors
);
executeAgent(inputDataSet, mlAgent, agentActionListener);
}, ex -> {
log.error("Failed to create parent interaction", ex);
listener.onFailure(ex);
}));
ActionListener<Object> agentActionListener = createAgentActionListener(listener, outputs, modelTensors);
// get question for regenerate
if (regenerateInteractionId != null) {
log.info("Regenerate for existing interaction {}", regenerateInteractionId);
client
.execute(
GetInteractionAction.INSTANCE,
new GetInteractionRequest(memoryId, regenerateInteractionId),
ActionListener.wrap(interactionRes -> {
inputDataSet
.getParameters()
.putIfAbsent(QUESTION, interactionRes.getInteraction().getInput());
saveRootInteractionAndExecute(agentActionListener, memory, inputDataSet, mlAgent);
}, e -> {
log.error("Failed to get existing interaction for regeneration", e);
listener.onFailure(e);
})
);
} else {
saveRootInteractionAndExecute(agentActionListener, memory, inputDataSet, mlAgent);
}
}, ex -> {
log.error("Failed to read conversation memory", ex);
listener.onFailure(ex);
Expand All @@ -167,6 +174,54 @@ public void execute(Input input, ActionListener<Output> listener) {

}

/**
* save root interaction and start execute the agent
* @param listener callback listener
* @param memory memory instance
* @param inputDataSet input
* @param mlAgent agent to run
*/
private void saveRootInteractionAndExecute(
ActionListener<Object> listener,
ConversationIndexMemory memory,
RemoteInferenceInputDataSet inputDataSet,
MLAgent mlAgent
) {
String appType = mlAgent.getAppType();
String question = inputDataSet.getParameters().get(QUESTION);
String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID);
// Create root interaction ID
ConversationIndexMessage msg = ConversationIndexMessage
.conversationIndexMessageBuilder()
.type(appType)
.question(question)
.response("")
.finalAnswer(true)
.sessionId(memory.getConversationId())
.build();
memory.save(msg, null, null, null, ActionListener.<CreateInteractionResponse>wrap(interaction -> {
log.info("Created parent interaction ID: " + interaction.getId());
inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId());
// only delete previous interaction when new interaction created
if (regenerateInteractionId != null) {
memory
.getMemoryManager()
.deleteInteractionAndTrace(
regenerateInteractionId,
ActionListener.wrap(deleted -> executeAgent(inputDataSet, mlAgent, listener), e -> {
log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e);
listener.onFailure(e);
})
);
} else {
executeAgent(inputDataSet, mlAgent, listener);
}
}, ex -> {
log.error("Failed to create parent interaction", ex);
listener.onFailure(ex);
}));
}

private void executeAgent(RemoteInferenceInputDataSet inputDataSet, MLAgent mlAgent, ActionListener<Object> agentActionListener) {
MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent);
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
Expand Down Expand Up @@ -233,4 +238,62 @@ public void updateInteraction(String interactionId, Map<String, Object> updateCo
actionListener.onFailure(exception);
}
}

/**
* Delete interaction and its trace data
* @param interactionId interaction id
* @param listener callback for delete result
*/
public void deleteInteractionAndTrace(String interactionId, ActionListener<Boolean> listener) {
DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(INTERACTIONS_INDEX_NAME);
deleteByQueryRequest.setQuery(buildDeleteInteractionQuery(interactionId));
deleteByQueryRequest.setRefresh(true);

innerDeleteInteractionAndTrace(deleteByQueryRequest, interactionId, listener);
}

@VisibleForTesting
void innerDeleteInteractionAndTrace(DeleteByQueryRequest deleteByQueryRequest, String interactionId, ActionListener<Boolean> listener) {
try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext().stashContext()) {
ActionListener<BulkByScrollResponse> al = ActionListener.wrap(bulkResponse -> {
if (bulkResponse != null && (!bulkResponse.getBulkFailures().isEmpty() || !bulkResponse.getSearchFailures().isEmpty())) {
log.info("Failed to delete the interaction with ID: {}", interactionId);
listener.onResponse(false);
return;
}
log.info("Successfully delete the interaction with ID: {}", interactionId);
listener.onResponse(true);
}, exception -> {
log.error("Failed to delete interaction with ID {}. Details: {}", interactionId, exception);
listener.onFailure(exception);
});
// bulk delete interaction and its trace
client.execute(DeleteByQueryAction.INSTANCE, deleteByQueryRequest, al);
} catch (Exception e) {
log.error("Failed to delete interaction with ID {}. Details {}:", interactionId, e);
listener.onFailure(e);
}
}

@VisibleForTesting
QueryBuilder buildDeleteInteractionQuery(String interactionId) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
// interaction itself
boolQueryBuilder.should(QueryBuilders.idsQuery().addIds(interactionId));

// Build the trace query
BoolQueryBuilder traceBoolBuilder = QueryBuilders.boolQuery();
// Add the ExistsQueryBuilder for checking null values
ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD);
traceBoolBuilder.must(existsQueryBuilder);

// Add the TermQueryBuilder for another field
TermQueryBuilder termQueryBuilder = QueryBuilders
.termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, interactionId);
traceBoolBuilder.must(termQueryBuilder);

// interaction trace
boolQueryBuilder.should(traceBoolBuilder);
return boolQueryBuilder;
}
}
Loading

0 comments on commit f30c3cb

Please sign in to comment.