Skip to content

Commit

Permalink
refactor memory manager and Get Trace actions
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Dec 15, 2023
1 parent 9a6823a commit 2a2ccd0
Show file tree
Hide file tree
Showing 12 changed files with 1,061 additions and 70 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionType;

/**
* Action to return the traces associated with an interaction
*/
public class GetTracesAction extends ActionType<GetTracesResponse> {
/** Instance of this */
public static final GetTracesAction INSTANCE = new GetTracesAction();
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/memory/trace/get";

private GetTracesAction() {
super(NAME, GetTracesResponse::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.IOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

import lombok.Getter;

/**
* ActionRequest for get traces
*/
public class GetTracesRequest extends ActionRequest {
@Getter
private String interactionId;
@Getter
private int maxResults = ActionConstants.DEFAULT_MAX_RESULTS;
@Getter
private int from = 0;

/**
* Constructor
* @param interactionId UID of the interaction to get traces from
*/
public GetTracesRequest(String interactionId) {
this.interactionId = interactionId;
}

/**
* Constructor
* @param interactionId UID of the conversation to get interactions from
* @param maxResults number of interactions to retrieve
*/
public GetTracesRequest(String interactionId, int maxResults) {
this.interactionId = interactionId;
this.maxResults = maxResults;
}

/**
* Constructor
* @param interactionId UID of the conversation to get interactions from
* @param maxResults number of interactions to retrieve
* @param from position of first interaction to retrieve
*/
public GetTracesRequest(String interactionId, int maxResults, int from) {
this.interactionId = interactionId;
this.maxResults = maxResults;
this.from = from;
}

/**
* Constructor
* @param in streaminput to read this from. assumes there was a GetTracesRequest.writeTo
* @throws IOException if there wasn't a GIR in the stream
*/
public GetTracesRequest(StreamInput in) throws IOException {
super(in);
this.interactionId = in.readString();
this.maxResults = in.readInt();
this.from = in.readInt();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(interactionId);
out.writeInt(maxResults);
out.writeInt(from);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (interactionId == null) {
exception = addValidationError("Traces must be retrieved from an interaction", exception);
}
if (maxResults <= 0) {
exception = addValidationError("The number of traces to retrieve must be positive", exception);
}
if (from < 0) {
exception = addValidationError("The starting position must be nonnegative", exception);
}

return exception;
}

/**
* Makes a GetTracesRequest out of a RestRequest
* @param request Rest Request representing a get traces request
* @return a new GetTracesRequest
* @throws IOException if something goes wrong
*/
public static GetTracesRequest fromRestRequest(RestRequest request) throws IOException {
String cid = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD);
if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) {
int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD));
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD));
return new GetTracesRequest(cid, maxResults, from);
} else {
return new GetTracesRequest(cid, ActionConstants.DEFAULT_MAX_RESULTS, from);
}
} else {
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD));
return new GetTracesRequest(cid, maxResults);
} else {
return new GetTracesRequest(cid);
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import java.io.IOException;
import java.util.List;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.conversation.Interaction;

import lombok.AllArgsConstructor;
import lombok.Getter;

/**
* Action Response for get traces for an interaction
*/
@AllArgsConstructor
public class GetTracesResponse extends ActionResponse implements ToXContentObject {
@Getter
private List<Interaction> traces;
@Getter
private int nextToken;
private boolean hasMoreTokens;

/**
* Constructor
* @param in stream input; assumes GetTracesResponse.writeTo was called
* @throws IOException if there's not a G.I.R. in the stream
*/
public GetTracesResponse(StreamInput in) throws IOException {
super(in);
traces = in.readList(Interaction::fromStream);
nextToken = in.readInt();
hasMoreTokens = in.readBoolean();
}

public void writeTo(StreamOutput out) throws IOException {
out.writeList(traces);
out.writeInt(nextToken);
out.writeBoolean(hasMoreTokens);
}

/**
* Are there more pages in this search results
* @return whether there are more traces in this search
*/
public boolean hasMorePages() {
return hasMoreTokens;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.startArray(ActionConstants.RESPONSE_TRACES_LIST_FIELD);
for (Interaction trace : traces) {
trace.toXContent(builder, params);
}
builder.endArray();
if (hasMoreTokens) {
builder.field(ActionConstants.NEXT_TOKEN_FIELD, nextToken);
}
builder.endObject();
return builder;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import java.util.List;

import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class GetTracesTransportAction extends HandledTransportAction<GetTracesRequest, GetTracesResponse> {
private Client client;
private ConversationalMemoryHandler cmHandler;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public GetTracesTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client,
ClusterService clusterService
) {
super(GetTracesAction.NAME, transportService, actionFilters, GetTracesRequest::new);
this.client = client;
this.cmHandler = cmHandler;
}

@Override
public void doExecute(Task task, GetTracesRequest request, ActionListener<GetTracesResponse> actionListener) {
int maxResults = request.getMaxResults();
int from = request.getFrom();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<GetTracesResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<List<Interaction>> al = ActionListener.wrap(tracesList -> {
internalListener.onResponse(new GetTracesResponse(tracesList, from + maxResults, tracesList.size() == maxResults));
}, e -> { internalListener.onFailure(e); });
cmHandler.getTraces(request.getInteractionId(), from, maxResults, al);
} catch (Exception e) {
log.error("Failed to get traces for conversation " + request.getInteractionId(), e);
actionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import java.io.IOException;
import java.util.Map;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;

public class GetTracesRequestTests extends OpenSearchTestCase {

public void testConstructorsAndStreaming() throws IOException {
GetTracesRequest request = new GetTracesRequest("test-iid");
assert (request.validate() == null);
assert (request.getInteractionId().equals("test-iid"));
assert (request.getFrom() == 0);
assert (request.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS);

GetTracesRequest req2 = new GetTracesRequest("test-iid2", 3);
assert (req2.validate() == null);
assert (req2.getInteractionId().equals("test-iid2"));
assert (req2.getFrom() == 0);
assert (req2.getMaxResults() == 3);

GetTracesRequest req3 = new GetTracesRequest("test-iid3", 4, 5);
assert (req3.validate() == null);
assert (req3.getInteractionId().equals("test-iid3"));
assert (req3.getFrom() == 5);
assert (req3.getMaxResults() == 4);

BytesStreamOutput outbytes = new BytesStreamOutput();
StreamOutput osso = new OutputStreamStreamOutput(outbytes);
request.writeTo(osso);
StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes()));
GetTracesRequest req4 = new GetTracesRequest(in);
assert (req4.validate() == null);
assert (req4.getInteractionId().equals("test-iid"));
assert (req4.getFrom() == 0);
assert (req4.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS);
}

public void testBadValues_thenFail() {
String nullstr = null;
GetTracesRequest request = new GetTracesRequest(nullstr);
assert (request.validate().validationErrors().get(0).equals("Traces must be retrieved from an interaction"));
assert (request.validate().validationErrors().size() == 1);

request = new GetTracesRequest("iid", -2);
assert (request.validate().validationErrors().size() == 1);
assert (request.validate().validationErrors().get(0).equals("The number of traces to retrieve must be positive"));

request = new GetTracesRequest("iid", 2, -2);
assert (request.validate().validationErrors().size() == 1);
assert (request.validate().validationErrors().get(0).equals("The starting position must be nonnegative"));
}

public void testFromRestRequest() throws IOException {
Map<String, String> basic = Map.of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid1");
Map<String, String> maxResOnly = Map
.of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid2", ActionConstants.REQUEST_MAX_RESULTS_FIELD, "4");
Map<String, String> nextTokOnly = Map
.of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid3", ActionConstants.NEXT_TOKEN_FIELD, "6");
Map<String, String> bothFields = Map
.of(
ActionConstants.RESPONSE_INTERACTION_ID_FIELD,
"iid4",
ActionConstants.REQUEST_MAX_RESULTS_FIELD,
"2",
ActionConstants.NEXT_TOKEN_FIELD,
"7"
);
RestRequest req1 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(basic).build();
RestRequest req2 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(maxResOnly).build();
RestRequest req3 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(nextTokOnly).build();
RestRequest req4 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(bothFields).build();
GetTracesRequest gir1 = GetTracesRequest.fromRestRequest(req1);
GetTracesRequest gir2 = GetTracesRequest.fromRestRequest(req2);
GetTracesRequest gir3 = GetTracesRequest.fromRestRequest(req3);
GetTracesRequest gir4 = GetTracesRequest.fromRestRequest(req4);

assert (gir1.validate() == null && gir2.validate() == null && gir3.validate() == null && gir4.validate() == null);
assert (gir1.getInteractionId().equals("iid1") && gir2.getInteractionId().equals("iid2"));
assert (gir3.getInteractionId().equals("iid3") && gir4.getInteractionId().equals("iid4"));
assert (gir1.getFrom() == 0 && gir2.getFrom() == 0 && gir3.getFrom() == 6 && gir4.getFrom() == 7);
assert (gir1.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir2.getMaxResults() == 4);
assert (gir3.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir4.getMaxResults() == 2);
}
}
Loading

0 comments on commit 2a2ccd0

Please sign in to comment.