forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor memory manager and Get Trace actions
Signed-off-by: Xun Zhang <[email protected]>
- Loading branch information
1 parent
9a6823a
commit 2a2ccd0
Showing
12 changed files
with
1,061 additions
and
70 deletions.
There are no files selected for viewing
23 changes: 23 additions & 0 deletions
23
memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.action.ActionType; | ||
|
||
/** | ||
* Action to return the traces associated with an interaction | ||
*/ | ||
public class GetTracesAction extends ActionType<GetTracesResponse> { | ||
/** Instance of this */ | ||
public static final GetTracesAction INSTANCE = new GetTracesAction(); | ||
/** Name of this action */ | ||
public static final String NAME = "cluster:admin/opensearch/ml/memory/trace/get"; | ||
|
||
private GetTracesAction() { | ||
super(NAME, GetTracesResponse::new); | ||
} | ||
|
||
} |
124 changes: 124 additions & 0 deletions
124
memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
import java.io.IOException; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.ml.common.conversation.ActionConstants; | ||
import org.opensearch.rest.RestRequest; | ||
|
||
import lombok.Getter; | ||
|
||
/** | ||
* ActionRequest for get traces | ||
*/ | ||
public class GetTracesRequest extends ActionRequest { | ||
@Getter | ||
private String interactionId; | ||
@Getter | ||
private int maxResults = ActionConstants.DEFAULT_MAX_RESULTS; | ||
@Getter | ||
private int from = 0; | ||
|
||
/** | ||
* Constructor | ||
* @param interactionId UID of the interaction to get traces from | ||
*/ | ||
public GetTracesRequest(String interactionId) { | ||
this.interactionId = interactionId; | ||
} | ||
|
||
/** | ||
* Constructor | ||
* @param interactionId UID of the conversation to get interactions from | ||
* @param maxResults number of interactions to retrieve | ||
*/ | ||
public GetTracesRequest(String interactionId, int maxResults) { | ||
this.interactionId = interactionId; | ||
this.maxResults = maxResults; | ||
} | ||
|
||
/** | ||
* Constructor | ||
* @param interactionId UID of the conversation to get interactions from | ||
* @param maxResults number of interactions to retrieve | ||
* @param from position of first interaction to retrieve | ||
*/ | ||
public GetTracesRequest(String interactionId, int maxResults, int from) { | ||
this.interactionId = interactionId; | ||
this.maxResults = maxResults; | ||
this.from = from; | ||
} | ||
|
||
/** | ||
* Constructor | ||
* @param in streaminput to read this from. assumes there was a GetTracesRequest.writeTo | ||
* @throws IOException if there wasn't a GIR in the stream | ||
*/ | ||
public GetTracesRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.interactionId = in.readString(); | ||
this.maxResults = in.readInt(); | ||
this.from = in.readInt(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(interactionId); | ||
out.writeInt(maxResults); | ||
out.writeInt(from); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
if (interactionId == null) { | ||
exception = addValidationError("Traces must be retrieved from an interaction", exception); | ||
} | ||
if (maxResults <= 0) { | ||
exception = addValidationError("The number of traces to retrieve must be positive", exception); | ||
} | ||
if (from < 0) { | ||
exception = addValidationError("The starting position must be nonnegative", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
/** | ||
* Makes a GetTracesRequest out of a RestRequest | ||
* @param request Rest Request representing a get traces request | ||
* @return a new GetTracesRequest | ||
* @throws IOException if something goes wrong | ||
*/ | ||
public static GetTracesRequest fromRestRequest(RestRequest request) throws IOException { | ||
String cid = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); | ||
if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) { | ||
int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD)); | ||
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { | ||
int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); | ||
return new GetTracesRequest(cid, maxResults, from); | ||
} else { | ||
return new GetTracesRequest(cid, ActionConstants.DEFAULT_MAX_RESULTS, from); | ||
} | ||
} else { | ||
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { | ||
int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); | ||
return new GetTracesRequest(cid, maxResults); | ||
} else { | ||
return new GetTracesRequest(cid); | ||
} | ||
} | ||
} | ||
|
||
} |
75 changes: 75 additions & 0 deletions
75
memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
|
||
import org.opensearch.core.action.ActionResponse; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.ToXContent; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.ml.common.conversation.ActionConstants; | ||
import org.opensearch.ml.common.conversation.Interaction; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Getter; | ||
|
||
/** | ||
* Action Response for get traces for an interaction | ||
*/ | ||
@AllArgsConstructor | ||
public class GetTracesResponse extends ActionResponse implements ToXContentObject { | ||
@Getter | ||
private List<Interaction> traces; | ||
@Getter | ||
private int nextToken; | ||
private boolean hasMoreTokens; | ||
|
||
/** | ||
* Constructor | ||
* @param in stream input; assumes GetTracesResponse.writeTo was called | ||
* @throws IOException if there's not a G.I.R. in the stream | ||
*/ | ||
public GetTracesResponse(StreamInput in) throws IOException { | ||
super(in); | ||
traces = in.readList(Interaction::fromStream); | ||
nextToken = in.readInt(); | ||
hasMoreTokens = in.readBoolean(); | ||
} | ||
|
||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeList(traces); | ||
out.writeInt(nextToken); | ||
out.writeBoolean(hasMoreTokens); | ||
} | ||
|
||
/** | ||
* Are there more pages in this search results | ||
* @return whether there are more traces in this search | ||
*/ | ||
public boolean hasMorePages() { | ||
return hasMoreTokens; | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { | ||
builder.startObject(); | ||
builder.startArray(ActionConstants.RESPONSE_TRACES_LIST_FIELD); | ||
for (Interaction trace : traces) { | ||
trace.toXContent(builder, params); | ||
} | ||
builder.endArray(); | ||
if (hasMoreTokens) { | ||
builder.field(ActionConstants.NEXT_TOKEN_FIELD, nextToken); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
} |
66 changes: 66 additions & 0 deletions
66
.../src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import java.util.List; | ||
|
||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.common.conversation.Interaction; | ||
import org.opensearch.ml.memory.ConversationalMemoryHandler; | ||
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class GetTracesTransportAction extends HandledTransportAction<GetTracesRequest, GetTracesResponse> { | ||
private Client client; | ||
private ConversationalMemoryHandler cmHandler; | ||
|
||
/** | ||
* Constructor | ||
* @param transportService for inter-node communications | ||
* @param actionFilters for filtering actions | ||
* @param cmHandler Handler for conversational memory operations | ||
* @param client OS Client for dealing with OS | ||
* @param clusterService for some cluster ops | ||
*/ | ||
@Inject | ||
public GetTracesTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
OpenSearchConversationalMemoryHandler cmHandler, | ||
Client client, | ||
ClusterService clusterService | ||
) { | ||
super(GetTracesAction.NAME, transportService, actionFilters, GetTracesRequest::new); | ||
this.client = client; | ||
this.cmHandler = cmHandler; | ||
} | ||
|
||
@Override | ||
public void doExecute(Task task, GetTracesRequest request, ActionListener<GetTracesResponse> actionListener) { | ||
int maxResults = request.getMaxResults(); | ||
int from = request.getFrom(); | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { | ||
ActionListener<GetTracesResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); | ||
ActionListener<List<Interaction>> al = ActionListener.wrap(tracesList -> { | ||
internalListener.onResponse(new GetTracesResponse(tracesList, from + maxResults, tracesList.size() == maxResults)); | ||
}, e -> { internalListener.onFailure(e); }); | ||
cmHandler.getTraces(request.getInteractionId(), from, maxResults, al); | ||
} catch (Exception e) { | ||
log.error("Failed to get traces for conversation " + request.getInteractionId(), e); | ||
actionListener.onFailure(e); | ||
} | ||
} | ||
} |
101 changes: 101 additions & 0 deletions
101
memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import java.io.IOException; | ||
import java.util.Map; | ||
|
||
import org.opensearch.common.io.stream.BytesStreamOutput; | ||
import org.opensearch.core.common.bytes.BytesReference; | ||
import org.opensearch.core.common.io.stream.BytesStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.ml.common.conversation.ActionConstants; | ||
import org.opensearch.rest.RestRequest; | ||
import org.opensearch.test.OpenSearchTestCase; | ||
import org.opensearch.test.rest.FakeRestRequest; | ||
|
||
public class GetTracesRequestTests extends OpenSearchTestCase { | ||
|
||
public void testConstructorsAndStreaming() throws IOException { | ||
GetTracesRequest request = new GetTracesRequest("test-iid"); | ||
assert (request.validate() == null); | ||
assert (request.getInteractionId().equals("test-iid")); | ||
assert (request.getFrom() == 0); | ||
assert (request.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); | ||
|
||
GetTracesRequest req2 = new GetTracesRequest("test-iid2", 3); | ||
assert (req2.validate() == null); | ||
assert (req2.getInteractionId().equals("test-iid2")); | ||
assert (req2.getFrom() == 0); | ||
assert (req2.getMaxResults() == 3); | ||
|
||
GetTracesRequest req3 = new GetTracesRequest("test-iid3", 4, 5); | ||
assert (req3.validate() == null); | ||
assert (req3.getInteractionId().equals("test-iid3")); | ||
assert (req3.getFrom() == 5); | ||
assert (req3.getMaxResults() == 4); | ||
|
||
BytesStreamOutput outbytes = new BytesStreamOutput(); | ||
StreamOutput osso = new OutputStreamStreamOutput(outbytes); | ||
request.writeTo(osso); | ||
StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); | ||
GetTracesRequest req4 = new GetTracesRequest(in); | ||
assert (req4.validate() == null); | ||
assert (req4.getInteractionId().equals("test-iid")); | ||
assert (req4.getFrom() == 0); | ||
assert (req4.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); | ||
} | ||
|
||
public void testBadValues_thenFail() { | ||
String nullstr = null; | ||
GetTracesRequest request = new GetTracesRequest(nullstr); | ||
assert (request.validate().validationErrors().get(0).equals("Traces must be retrieved from an interaction")); | ||
assert (request.validate().validationErrors().size() == 1); | ||
|
||
request = new GetTracesRequest("iid", -2); | ||
assert (request.validate().validationErrors().size() == 1); | ||
assert (request.validate().validationErrors().get(0).equals("The number of traces to retrieve must be positive")); | ||
|
||
request = new GetTracesRequest("iid", 2, -2); | ||
assert (request.validate().validationErrors().size() == 1); | ||
assert (request.validate().validationErrors().get(0).equals("The starting position must be nonnegative")); | ||
} | ||
|
||
public void testFromRestRequest() throws IOException { | ||
Map<String, String> basic = Map.of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid1"); | ||
Map<String, String> maxResOnly = Map | ||
.of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid2", ActionConstants.REQUEST_MAX_RESULTS_FIELD, "4"); | ||
Map<String, String> nextTokOnly = Map | ||
.of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid3", ActionConstants.NEXT_TOKEN_FIELD, "6"); | ||
Map<String, String> bothFields = Map | ||
.of( | ||
ActionConstants.RESPONSE_INTERACTION_ID_FIELD, | ||
"iid4", | ||
ActionConstants.REQUEST_MAX_RESULTS_FIELD, | ||
"2", | ||
ActionConstants.NEXT_TOKEN_FIELD, | ||
"7" | ||
); | ||
RestRequest req1 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(basic).build(); | ||
RestRequest req2 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(maxResOnly).build(); | ||
RestRequest req3 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(nextTokOnly).build(); | ||
RestRequest req4 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(bothFields).build(); | ||
GetTracesRequest gir1 = GetTracesRequest.fromRestRequest(req1); | ||
GetTracesRequest gir2 = GetTracesRequest.fromRestRequest(req2); | ||
GetTracesRequest gir3 = GetTracesRequest.fromRestRequest(req3); | ||
GetTracesRequest gir4 = GetTracesRequest.fromRestRequest(req4); | ||
|
||
assert (gir1.validate() == null && gir2.validate() == null && gir3.validate() == null && gir4.validate() == null); | ||
assert (gir1.getInteractionId().equals("iid1") && gir2.getInteractionId().equals("iid2")); | ||
assert (gir3.getInteractionId().equals("iid3") && gir4.getInteractionId().equals("iid4")); | ||
assert (gir1.getFrom() == 0 && gir2.getFrom() == 0 && gir3.getFrom() == 6 && gir4.getFrom() == 7); | ||
assert (gir1.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir2.getMaxResults() == 4); | ||
assert (gir3.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir4.getMaxResults() == 2); | ||
} | ||
} |
Oops, something went wrong.