Skip to content

Commit

Permalink
Add Memory class in the plugin for agents
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Nov 13, 2023
1 parent 283dd12 commit 3a001cf
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ public class Interaction implements Writeable, ToXContentObject {
private String origin;
@Getter
private Map<String, String> additionalInfo;
@Getter
private String parentInteractionId;
@Getter
private Integer traceNum;

/**
* Creates an Interaction object from a map of fields in the OS index
Expand All @@ -71,7 +75,9 @@ public static Interaction fromMap(String id, Map<String, Object> fields) {
String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD);
String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD);
Map<String,String> additionalInfo = (Map<String,String>) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD);
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo);
String parentInteractionId = (String) fields.getOrDefault(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, null);
Integer traceNum = (Integer) fields.getOrDefault(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, null);
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo, parentInteractionId, traceNum);
}

/**
Expand Down Expand Up @@ -102,7 +108,9 @@ public static Interaction fromStream(StreamInput in) throws IOException {
if (in.readBoolean()) {
additionalInfo = in.readMap(s -> s.readString(), s -> s.readString());
}
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo);
String parentInteractionId = in.readOptionalString();
Integer traceNum = in.readOptionalInt();
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo, parentInteractionId, traceNum);
}


Expand All @@ -121,6 +129,8 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalString(parentInteractionId);
out.writeOptionalInt(traceNum);
}

@Override
Expand All @@ -136,6 +146,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
if(additionalInfo != null) {
builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo);
}
if (parentInteractionId != null) {
builder.field(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId);
}
if (traceNum != null) {
builder.field(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNum);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ public class CreateInteractionRequest extends ActionRequest {
@Getter
private Map<String, String> additionalInfo;
@Getter
private String parent_interaction_id;
private String parentInteractionId;
@Getter
private Integer trace_number;
private Integer traceNum;

public CreateInteractionRequest(
String conversationId,
Expand Down Expand Up @@ -89,8 +89,8 @@ public CreateInteractionRequest(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.additionalInfo = in.readMap(s -> s.readString(), s -> s.readString());
}
this.parent_interaction_id = in.readOptionalString();
this.trace_number = in.readOptionalInt();
this.parentInteractionId = in.readOptionalString();
this.traceNum = in.readOptionalInt();
}

@Override
Expand All @@ -107,8 +107,8 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalString(parent_interaction_id);
out.writeOptionalInt(trace_number);
out.writeOptionalString(parentInteractionId);
out.writeOptionalInt(traceNum);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,18 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList
String ogn = request.getOrigin();
String prompt = request.getPromptTemplate();
Map<String, String> additionalInfo = request.getAdditionalInfo();
String parintid = request.getParent_interaction_id();
Integer traceNumber = request.getTrace_number();
String parentId = request.getParentInteractionId();
Integer traceNumber = request.getTraceNum();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateInteractionResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<String> al = ActionListener
.wrap(iid -> { internalListener.onResponse(new CreateInteractionResponse(iid)); }, e -> {
internalListener.onFailure(e);
});
if (parintid == null || traceNumber == null) {
if (parentId == null || traceNumber == null) {
cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al);
} else {
cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al, parintid, traceNumber);
cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al, parentId, traceNumber);
}
} catch (Exception e) {
log.error("Failed to create interaction for conversation " + cid, e);
Expand Down
104 changes: 104 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/memory/MLMemoryManager.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.opensearch.client.Client;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.ml.memory.action.conversation.CreateInteractionAction;
import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.memory.action.conversation.GetInteractionsAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest;
import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse;
import org.opensearch.ml.repackage.com.google.common.base.Preconditions;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
* Memory manager for Memories. It contains ML memory related operations like create, read interactions etc.
*/
@Log4j2
@AllArgsConstructor
public class MLMemoryManager {
public static final int DEFAULT_TIMEOUT_IN_MILLIS = 5000;

private Client client;


public String createConversation(String name, String applicationType) {

CreateConversationResponse response = client
.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name, applicationType))
.actionGet(DEFAULT_TIMEOUT_IN_MILLIS);
log.info("createConversation: id: {}", response.getId());
return response.getId();
}

public String createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
Map<String, String> additionalInfo,
String parentIntId,
Integer traceNum
) {
Preconditions.checkNotNull(conversationId);
Preconditions.checkNotNull(input);
Preconditions.checkNotNull(response);
CreateInteractionResponse res = client
.execute(
CreateInteractionAction.INSTANCE,
new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo, parentIntId, traceNum)
)
.actionGet(DEFAULT_TIMEOUT_IN_MILLIS);
log.info("createInteraction: interactionId: {}", res.getId());
return res.getId();
}

public List<Interaction> getInteractions(String conversationId, int lastNInteraction) {

Preconditions.checkArgument(lastNInteraction > 0, "lastN must be at least 1.");

log.info("Getting Interactions, conversationId {}, lastN {}", conversationId, lastNInteraction);

List<Interaction> interactions = new ArrayList<>();
int from = 0;
boolean allInteractionsFetched = false;
int maxResults = lastNInteraction;
do {
GetInteractionsResponse response = client
.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, maxResults, from))
.actionGet(DEFAULT_TIMEOUT_IN_MILLIS);
List<Interaction> list = response.getInteractions();
if (list != null && !CollectionUtils.isEmpty(list)) {
interactions.addAll(list);
from += list.size();
maxResults -= list.size();
log.info("Interactions: {}, from: {}, maxResults: {}", interactions, from, maxResults);
} else if (response.hasMorePages()) {
// If we didn't get any results back, we ignore this flag and break out of the loop
// to avoid an infinite loop.
// But in the future, we may support this mode, e.g. DynamoDB.
break;
}
log.info("Interactions: {}, from: {}, maxResults: {}", interactions, from, maxResults);
allInteractionsFetched = !response.hasMorePages();
} while (from < lastNInteraction && !allInteractionsFetched);

return interactions;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
Collections.singletonMap("metadata", jsonArrayToString(searchResults))
Collections.singletonMap("search_results", jsonArrayToString(searchResults))
);
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
}
Expand Down

0 comments on commit 3a001cf

Please sign in to comment.