Skip to content

Commit

Permalink
add new fields in the memory and refactor transport actions
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Nov 12, 2023
1 parent 6efb1eb commit 28c266e
Show file tree
Hide file tree
Showing 23 changed files with 658 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ public class ActionConstants {
public final static String PROMPT_TEMPLATE_FIELD = "prompt_template";
/** name of metadata field in all requests */
public final static String ADDITIONAL_INFO_FIELD = "additional_info";
/** name of metadata field in all requests */
public final static String INTERACTION_ID_FIELD = "interaction_id";
/** name of metadata field in all requests */
public final static String TRACE_NUMBER_FIELD = "trace_number";
/** name of success field in all requests */
public final static String SUCCESS_FIELD = "success";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public class ConversationalIndexConstants {
public final static String META_NAME_FIELD = "name";
/** Name of the owning user field in all indices */
public final static String USER_FIELD = "user";
/** Name of the application that created this conversation */
public final static String APPLICATION_TYPE_FIELD = "application_type";
/** Mappings for the conversational metadata index */
public final static String META_MAPPING = "{\n"
+ " \"_meta\": {\n"
Expand All @@ -47,6 +49,9 @@ public class ConversationalIndexConstants {
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ USER_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ APPLICATION_TYPE_FIELD
+ "\": {\"type\": \"keyword\"}\n"
+ " }\n"
+ "}";
Expand All @@ -69,6 +74,10 @@ public class ConversationalIndexConstants {
public final static String INTERACTIONS_ADDITIONAL_INFO_FIELD = "additional_info";
/** Name of the interaction field for the timestamp */
public final static String INTERACTIONS_CREATE_TIME_FIELD = "create_time";
/** Name of the interaction id */
public final static String INTERACTIONS_ID_FIELD = "interaction_id";
/** The trace number of an interaction */
public final static String INTERACTIONS_TRACE_NUMBER_FIELD = "trace_number";
/** Mappings for the interactions index */
public final static String INTERACTIONS_MAPPINGS = "{\n"
+ " \"_meta\": {\n"
Expand All @@ -95,7 +104,13 @@ public class ConversationalIndexConstants {
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ INTERACTIONS_ADDITIONAL_INFO_FIELD
+ "\": {\"type\": \"text\"}\n"
+ "\": {\"type\": \"flat_object\"},\n"
+ " \""
+ INTERACTIONS_ID_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ INTERACTIONS_TRACE_NUMBER_FIELD
+ "\": {\"type\": \"long\"}\n"
+ " }\n"
+ "}";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.core.common.io.stream.StreamInput;
Expand Down Expand Up @@ -54,7 +55,7 @@ public class Interaction implements Writeable, ToXContentObject {
@Getter
private String origin;
@Getter
private String additionalInfo;
private Map<String, String> additionalInfo;

/**
* Creates an Interaction object from a map of fields in the OS index
Expand All @@ -69,7 +70,7 @@ public static Interaction fromMap(String id, Map<String, Object> fields) {
String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD);
String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD);
String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD);
String additionalInfo = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD);
Map<String,String> additionalInfo = (Map<String,String>) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD);
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo);
}

Expand Down Expand Up @@ -97,7 +98,10 @@ public static Interaction fromStream(StreamInput in) throws IOException {
String promptTemplate = in.readString();
String response = in.readString();
String origin = in.readString();
String additionalInfo = in.readOptionalString();
Map<String, String> additionalInfo = new HashMap<>();
if (in.readBoolean()) {
additionalInfo = in.readMap(s -> s.readString(), s -> s.readString());
}
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo);
}

Expand All @@ -111,7 +115,12 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(promptTemplate);
out.writeString(response);
out.writeString(origin);
out.writeOptionalString(additionalInfo);
if (additionalInfo != null) {
out.writeBoolean(true);
out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.opensearch.ml.memory;

import java.util.List;
import java.util.Map;

import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -49,6 +50,14 @@ public interface ConversationalMemoryHandler {
*/
public void createConversation(String name, ActionListener<String> listener);

/**
* Create a new conversation
* @param name the name of the new conversation
* @param applicationType the application that creates this conversation
* @param listener listener to wait for this op to finish, gets unique id of new conversation
*/
public void createConversation(String name, String applicationType, ActionListener<String> listener);

/**
* Create a new conversation
* @param name the name of the new conversation
Expand All @@ -72,10 +81,34 @@ public void createInteraction(
String promptTemplate,
String response,
String origin,
String additionalInfo,
Map<String, String> additionalInfo,
ActionListener<String> listener
);

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
* @param input the human input for the interaction
* @param promptTemplate the prompt template used for this interaction
* @param response the Gen AI response for this interaction
* @param origin the name of the GenAI agent in this interaction
* @param additionalInfo additional information used in constructing the LLM prompt
* @param interactionId the parent interactionId of this interaction
* @param traceNumber the trace number for a parent interaction
* @param listener gets the ID of the new interaction
*/
public void createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
Map<String, String> additionalInfo,
ActionListener<String> listener,
String interactionId,
Integer traceNumber
);

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
Expand All @@ -92,7 +125,7 @@ public ActionFuture<String> createInteraction(
String promptTemplate,
String response,
String origin,
String additionalInfo
Map<String, String> additionalInfo
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
public class CreateConversationRequest extends ActionRequest {
@Getter
private String name = null;
@Getter
private String applicationType = null;

/**
* Constructor
Expand All @@ -44,6 +46,7 @@ public class CreateConversationRequest extends ActionRequest {
public CreateConversationRequest(StreamInput in) throws IOException {
super(in);
this.name = in.readOptionalString();
this.applicationType = in.readOptionalString();
}

/**
Expand All @@ -55,6 +58,16 @@ public CreateConversationRequest(String name) {
this.name = name;
}

/**
* Constructor
* @param name name of the conversation
*/
public CreateConversationRequest(String name, String applicationType) {
super();
this.name = name;
this.applicationType = applicationType;
}

/**
* Constructor
* name will be null
Expand All @@ -65,6 +78,7 @@ public CreateConversationRequest() {}
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(name);
out.writeOptionalString(applicationType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
return;
}
String name = request.getName();
String applicationType = request.getApplicationType();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateConversationResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<String> al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> {
Expand All @@ -92,7 +93,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
if (name == null) {
cmHandler.createConversation(al);
} else {
cmHandler.createConversation(name, al);
cmHandler.createConversation(name, applicationType, al);
}
} catch (Exception e) {
log.error("Failed to create new conversation with name " + request.getName(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

Expand All @@ -48,7 +52,27 @@ public class CreateInteractionRequest extends ActionRequest {
@Getter
private String origin;
@Getter
private String additionalInfo;
private Map<String, String> additionalInfo;
@Getter
private String interaction_id;
@Getter
private Integer trace_number;

public CreateInteractionRequest(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
Map<String, String> additionalInfo
) {
this.conversationId = conversationId;
this.input = input;
this.promptTemplate = promptTemplate;
this.response = response;
this.origin = origin;
this.additionalInfo = additionalInfo;
}

/**
* Constructor
Expand All @@ -62,7 +86,11 @@ public CreateInteractionRequest(StreamInput in) throws IOException {
this.promptTemplate = in.readString();
this.response = in.readString();
this.origin = in.readOptionalString();
this.additionalInfo = in.readOptionalString();
if (in.readBoolean()) {
this.additionalInfo = in.readMap(s -> s.readString(), s -> s.readString());
}
this.interaction_id = in.readOptionalString();
this.trace_number = in.readOptionalInt();
}

@Override
Expand All @@ -73,7 +101,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(promptTemplate);
out.writeString(response);
out.writeOptionalString(origin);
out.writeOptionalString(additionalInfo);
if (additionalInfo != null) {
out.writeBoolean(true);
out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(interaction_id);
out.writeOptionalInt(trace_number);
}

@Override
Expand All @@ -92,14 +127,55 @@ public ActionRequestValidationException validate() {
* @throws IOException if something goes wrong reading from request
*/
public static CreateInteractionRequest fromRestRequest(RestRequest request) throws IOException {
Map<String, String> body = request.contentParser().mapStrings();
String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD);
String inp = body.get(ActionConstants.INPUT_FIELD);
String prmpt = body.get(ActionConstants.PROMPT_TEMPLATE_FIELD);
String rsp = body.get(ActionConstants.AI_RESPONSE_FIELD);
String ogn = body.get(ActionConstants.RESPONSE_ORIGIN_FIELD);
String addinf = body.get(ActionConstants.ADDITIONAL_INFO_FIELD);
return new CreateInteractionRequest(cid, inp, prmpt, rsp, ogn, addinf);
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);

String cid = null;
String input = null;
String prompt = null;
String rep = null;
String origin = null;
Map<String, String> addinf = new HashMap<>();
String interactionid = null;
Integer tracenum = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case ActionConstants.CONVERSATION_ID_FIELD:
cid = parser.text();
break;
case ActionConstants.INPUT_FIELD:
input = parser.text();
break;
case ActionConstants.PROMPT_TEMPLATE_FIELD:
prompt = parser.text();
break;
case ActionConstants.AI_RESPONSE_FIELD:
rep = parser.text();
break;
case ActionConstants.RESPONSE_ORIGIN_FIELD:
origin = parser.text();
break;
case ActionConstants.ADDITIONAL_INFO_FIELD:
addinf = getParameterMap(parser.map());
break;
case ActionConstants.INTERACTION_ID_FIELD:
interactionid = parser.text();
break;
case ActionConstants.TRACE_NUMBER_FIELD:
tracenum = parser.intValue(false);
break;
default:
parser.skipChildren();
break;
}
}

return new CreateInteractionRequest(cid, input, prompt, rep, origin, addinf, interactionid, tracenum);
}

}
Loading

0 comments on commit 28c266e

Please sign in to comment.