From 2a2ccd09d125cb702e5100da5d731ca6f53b0c84 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Thu, 14 Dec 2023 17:12:57 -0800 Subject: [PATCH] refactor memory manager and Get Trace actions Signed-off-by: Xun Zhang --- .../action/conversation/GetTracesAction.java | 23 ++ .../action/conversation/GetTracesRequest.java | 124 ++++++++ .../conversation/GetTracesResponse.java | 75 +++++ .../GetTracesTransportAction.java | 66 +++++ .../conversation/GetTracesRequestTests.java | 101 +++++++ .../conversation/GetTracesResponseTests.java | 104 +++++++ .../GetTracesTransportActionTests.java | 168 +++++++++++ .../ml/engine/memory/MLMemoryManager.java | 82 +----- .../engine/memory/MLMemoryManagerTests.java | 277 ++++++++++++++++++ .../ml/plugin/MachineLearningPlugin.java | 10 +- .../ml/rest/RestMemoryGetTracesAction.java | 37 +++ .../rest/RestMemoryGetTracesActionTests.java | 64 ++++ 12 files changed, 1061 insertions(+), 70 deletions(-) create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java new file mode 100644 index 0000000000..0117df94b5 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action to return the traces associated with an interaction + */ +public class GetTracesAction extends ActionType { + /** Instance of this */ + public static final GetTracesAction INSTANCE = new GetTracesAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/trace/get"; + + private GetTracesAction() { + super(NAME, GetTracesResponse::new); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java new file mode 100644 index 0000000000..9b65f78148 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.Getter; + +/** + * ActionRequest for get traces + */ +public class GetTracesRequest extends ActionRequest { + @Getter + private String interactionId; + @Getter + private int maxResults = ActionConstants.DEFAULT_MAX_RESULTS; + @Getter + private int from = 0; + + /** + * Constructor + * @param interactionId UID of the interaction to get traces from + */ + public GetTracesRequest(String interactionId) { + this.interactionId = interactionId; + } + + /** + * Constructor + * @param interactionId UID of the conversation to get interactions from + * @param maxResults number of interactions to retrieve + */ + public GetTracesRequest(String interactionId, int maxResults) { + this.interactionId = interactionId; + this.maxResults = maxResults; + } + + /** + * Constructor + * @param interactionId UID of the conversation to get interactions from + * @param maxResults number of interactions to retrieve + * @param from position of first interaction to retrieve + */ + public GetTracesRequest(String interactionId, int maxResults, int from) { + this.interactionId = interactionId; + this.maxResults = maxResults; + this.from = from; + } + + /** + * Constructor + * @param in streaminput to read this from. assumes there was a GetTracesRequest.writeTo + * @throws IOException if there wasn't a GIR in the stream + */ + public GetTracesRequest(StreamInput in) throws IOException { + super(in); + this.interactionId = in.readString(); + this.maxResults = in.readInt(); + this.from = in.readInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(interactionId); + out.writeInt(maxResults); + out.writeInt(from); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (interactionId == null) { + exception = addValidationError("Traces must be retrieved from an interaction", exception); + } + if (maxResults <= 0) { + exception = addValidationError("The number of traces to retrieve must be positive", exception); + } + if (from < 0) { + exception = addValidationError("The starting position must be nonnegative", exception); + } + + return exception; + } + + /** + * Makes a GetTracesRequest out of a RestRequest + * @param request Rest Request representing a get traces request + * @return a new GetTracesRequest + * @throws IOException if something goes wrong + */ + public static GetTracesRequest fromRestRequest(RestRequest request) throws IOException { + String cid = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); + if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) { + int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD)); + if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { + int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); + return new GetTracesRequest(cid, maxResults, from); + } else { + return new GetTracesRequest(cid, ActionConstants.DEFAULT_MAX_RESULTS, from); + } + } else { + if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { + int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); + return new GetTracesRequest(cid, maxResults); + } else { + return new GetTracesRequest(cid); + } + } + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java new file mode 100644 index 0000000000..df27f655bf --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.Interaction; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Action Response for get traces for an interaction + */ +@AllArgsConstructor +public class GetTracesResponse extends ActionResponse implements ToXContentObject { + @Getter + private List traces; + @Getter + private int nextToken; + private boolean hasMoreTokens; + + /** + * Constructor + * @param in stream input; assumes GetTracesResponse.writeTo was called + * @throws IOException if there's not a G.I.R. in the stream + */ + public GetTracesResponse(StreamInput in) throws IOException { + super(in); + traces = in.readList(Interaction::fromStream); + nextToken = in.readInt(); + hasMoreTokens = in.readBoolean(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeList(traces); + out.writeInt(nextToken); + out.writeBoolean(hasMoreTokens); + } + + /** + * Are there more pages in this search results + * @return whether there are more traces in this search + */ + public boolean hasMorePages() { + return hasMoreTokens; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.startArray(ActionConstants.RESPONSE_TRACES_LIST_FIELD); + for (Interaction trace : traces) { + trace.toXContent(builder, params); + } + builder.endArray(); + if (hasMoreTokens) { + builder.field(ActionConstants.NEXT_TOKEN_FIELD, nextToken); + } + builder.endObject(); + return builder; + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java new file mode 100644 index 0000000000..cd57d1823c --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.util.List; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetTracesTransportAction extends HandledTransportAction { + private Client client; + private ConversationalMemoryHandler cmHandler; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public GetTracesTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(GetTracesAction.NAME, transportService, actionFilters, GetTracesRequest::new); + this.client = client; + this.cmHandler = cmHandler; + } + + @Override + public void doExecute(Task task, GetTracesRequest request, ActionListener actionListener) { + int maxResults = request.getMaxResults(); + int from = request.getFrom(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener> al = ActionListener.wrap(tracesList -> { + internalListener.onResponse(new GetTracesResponse(tracesList, from + maxResults, tracesList.size() == maxResults)); + }, e -> { internalListener.onFailure(e); }); + cmHandler.getTraces(request.getInteractionId(), from, maxResults, al); + } catch (Exception e) { + log.error("Failed to get traces for conversation " + request.getInteractionId(), e); + actionListener.onFailure(e); + } + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java new file mode 100644 index 0000000000..0b88bd48c6 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetTracesRequestTests extends OpenSearchTestCase { + + public void testConstructorsAndStreaming() throws IOException { + GetTracesRequest request = new GetTracesRequest("test-iid"); + assert (request.validate() == null); + assert (request.getInteractionId().equals("test-iid")); + assert (request.getFrom() == 0); + assert (request.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); + + GetTracesRequest req2 = new GetTracesRequest("test-iid2", 3); + assert (req2.validate() == null); + assert (req2.getInteractionId().equals("test-iid2")); + assert (req2.getFrom() == 0); + assert (req2.getMaxResults() == 3); + + GetTracesRequest req3 = new GetTracesRequest("test-iid3", 4, 5); + assert (req3.validate() == null); + assert (req3.getInteractionId().equals("test-iid3")); + assert (req3.getFrom() == 5); + assert (req3.getMaxResults() == 4); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetTracesRequest req4 = new GetTracesRequest(in); + assert (req4.validate() == null); + assert (req4.getInteractionId().equals("test-iid")); + assert (req4.getFrom() == 0); + assert (req4.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); + } + + public void testBadValues_thenFail() { + String nullstr = null; + GetTracesRequest request = new GetTracesRequest(nullstr); + assert (request.validate().validationErrors().get(0).equals("Traces must be retrieved from an interaction")); + assert (request.validate().validationErrors().size() == 1); + + request = new GetTracesRequest("iid", -2); + assert (request.validate().validationErrors().size() == 1); + assert (request.validate().validationErrors().get(0).equals("The number of traces to retrieve must be positive")); + + request = new GetTracesRequest("iid", 2, -2); + assert (request.validate().validationErrors().size() == 1); + assert (request.validate().validationErrors().get(0).equals("The starting position must be nonnegative")); + } + + public void testFromRestRequest() throws IOException { + Map basic = Map.of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid1"); + Map maxResOnly = Map + .of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid2", ActionConstants.REQUEST_MAX_RESULTS_FIELD, "4"); + Map nextTokOnly = Map + .of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid3", ActionConstants.NEXT_TOKEN_FIELD, "6"); + Map bothFields = Map + .of( + ActionConstants.RESPONSE_INTERACTION_ID_FIELD, + "iid4", + ActionConstants.REQUEST_MAX_RESULTS_FIELD, + "2", + ActionConstants.NEXT_TOKEN_FIELD, + "7" + ); + RestRequest req1 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(basic).build(); + RestRequest req2 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(maxResOnly).build(); + RestRequest req3 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(nextTokOnly).build(); + RestRequest req4 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(bothFields).build(); + GetTracesRequest gir1 = GetTracesRequest.fromRestRequest(req1); + GetTracesRequest gir2 = GetTracesRequest.fromRestRequest(req2); + GetTracesRequest gir3 = GetTracesRequest.fromRestRequest(req3); + GetTracesRequest gir4 = GetTracesRequest.fromRestRequest(req4); + + assert (gir1.validate() == null && gir2.validate() == null && gir3.validate() == null && gir4.validate() == null); + assert (gir1.getInteractionId().equals("iid1") && gir2.getInteractionId().equals("iid2")); + assert (gir3.getInteractionId().equals("iid3") && gir4.getInteractionId().equals("iid4")); + assert (gir1.getFrom() == 0 && gir2.getFrom() == 0 && gir3.getFrom() == 6 && gir4.getFrom() == 7); + assert (gir1.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir2.getMaxResults() == 4); + assert (gir3.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir4.getMaxResults() == 2); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java new file mode 100644 index 0000000000..8a16d24fb9 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; + +public class GetTracesResponseTests extends OpenSearchTestCase { + List traces; + + @Before + public void setup() { + traces = List + .of( + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 1 + ), + new Interaction( + "id1", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 2 + ), + new Interaction( + "id2", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 3 + + ) + ); + } + + public void testGetInteractionsResponseStreaming() throws IOException { + GetTracesResponse response = new GetTracesResponse(traces, 4, true); + assert (response.getTraces().equals(traces)); + assert (response.getNextToken() == 4); + assert (response.hasMorePages()); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetTracesResponse newResp = new GetTracesResponse(in); + assert (newResp.getTraces().equals(traces)); + assert (newResp.getNextToken() == 4); + assert (newResp.hasMorePages()); + } + + public void testToXContent_MoreTokens() throws IOException { + GetTracesResponse response = new GetTracesResponse(traces.subList(0, 1), 2, true); + Interaction trace = response.getTraces().get(0); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + System.out.println(result); + String expected = "{\"traces\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"create_time\":" + + trace.getCreateTime() + + ",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"},\"parent_interaction_id\":\"parent_id\",\"trace_number\":1}],\"next_token\":2}"; + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + assert (ld.getDistance(result, expected) > 0.95); + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java new file mode 100644 index 0000000000..d855cd69ac --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java @@ -0,0 +1,168 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetTracesTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetTracesRequest request; + GetTracesTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetTracesRequest("test-iid"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetTracesTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testGetTraces_noMorePages() { + Interaction testTrace = new Interaction( + "test-trace", + Instant.now(), + "test-cid", + "test-input", + "pt", + "test-response", + "test-origin", + Collections.singletonMap("metadata", "some meta"), + "parent-id", + 1 + ); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(List.of(testTrace)); + return null; + }).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetTracesResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + List traces = argCaptor.getValue().getTraces(); + assert (traces.size() == 1); + Interaction trace = traces.get(0); + assert (trace.equals(testTrace)); + assert (!argCaptor.getValue().hasMorePages()); + } + + public void testGetTraces_MorePages() { + Interaction testTrace = new Interaction( + "test-trace", + Instant.now(), + "test-cid", + "test-input", + "pt", + "test-response", + "test-origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 1 + ); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(List.of(testTrace)); + return null; + }).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + GetTracesRequest shortPageRequest = new GetTracesRequest("test-trace", 1); + action.doExecute(null, shortPageRequest, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetTracesResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + List traces = argCaptor.getValue().getTraces(); + assert (traces.size() == 1); + Interaction trace = traces.get(0); + assert (trace.equals(testTrace)); + assert (argCaptor.getValue().hasMorePages()); + } + + public void testGetTracesFails_thenFail() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onFailure(new Exception("Testing Failure")); + return null; + }).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Testing Failure")); + } + + public void testDoExecuteFails_thenFail() { + doThrow(new RuntimeException("Failure in doExecute")).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in doExecute")); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index 15ebf14f02..9b6707c696 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -6,23 +6,12 @@ package org.opensearch.ml.engine.memory; import java.util.HashMap; -import java.util.LinkedList; import java.util.List; import java.util.Map; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; -import org.opensearch.client.Requests; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.ExistsQueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; @@ -33,14 +22,12 @@ import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.ml.memory.action.conversation.GetTracesResponse; import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; -import org.opensearch.ml.memory.index.ConversationMetaIndex; -import org.opensearch.search.SearchHit; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.sort.SortOrder; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import lombok.AllArgsConstructor; @@ -54,9 +41,6 @@ public class MLMemoryManager { private Client client; - private ClusterService clusterService; - private ConversationMetaIndex conversationMetaIndex; - private final String indexName = ConversationalIndexConstants.INTERACTIONS_INDEX_NAME; /** * Create a new Conversation @@ -128,6 +112,7 @@ public void createInteraction( * @param actionListener get all the final interactions that are not traces */ public void getFinalInteractions(String conversationId, int lastNInteraction, ActionListener> actionListener) { + Preconditions.checkNotNull(conversationId); Preconditions.checkArgument(lastNInteraction > 0, "lastN must be at least 1."); log.info("Getting Interactions, conversationId {}, lastN {}", conversationId, lastNInteraction); @@ -148,56 +133,17 @@ public void getFinalInteractions(String conversationId, int lastNInteraction, Ac * @param actionListener get all the trace interactions that are only traces */ public void getTraces(String parentInteractionId, ActionListener> actionListener) { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { - if (!clusterService.state().metadata().hasIndex(indexName)) { - actionListener.onResponse(List.of()); - return; - } - innerGetTraces(parentInteractionId, actionListener); - } catch (Exception e) { - actionListener.onFailure(e); - } - } + Preconditions.checkNotNull(parentInteractionId); + log.info("Getting traces for conversationId {}", parentInteractionId); - @VisibleForTesting - void innerGetTraces(String parentInteractionId, ActionListener> listener) { - SearchRequest searchRequest = Requests.searchRequest(indexName); - - // Build the query - BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); - - // Add the ExistsQueryBuilder for checking null values - ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD); - boolQueryBuilder.must(existsQueryBuilder); - - // Add the TermQueryBuilder for another field - TermQueryBuilder termQueryBuilder = QueryBuilders - .termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId); - boolQueryBuilder.must(termQueryBuilder); - - // Set the query to the search source - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(boolQueryBuilder); - searchRequest.source(searchSourceBuilder); - - searchRequest.source().sort(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, SortOrder.ASC); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { - ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - ActionListener al = ActionListener.wrap(response -> { - List result = new LinkedList(); - for (SearchHit hit : response.getHits()) { - result.add(Interaction.fromSearchHit(hit)); - } - internalListener.onResponse(result); - }, e -> { internalListener.onFailure(e); }); - client - .admin() - .indices() - .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(r -> { client.search(searchRequest, al); }, e -> { - internalListener.onFailure(e); - })); - } catch (Exception e) { - listener.onFailure(e); + ActionListener al = ActionListener.wrap(getTracesResponse -> { + actionListener.onResponse(getTracesResponse.getTraces()); + }, e -> { actionListener.onFailure(e); }); + + try { + client.execute(GetTracesAction.INSTANCE, new GetTracesRequest(parentInteractionId), al); + } catch (RuntimeException exception) { + actionListener.onFailure(exception); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java new file mode 100644 index 0000000000..b3a5f0da56 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java @@ -0,0 +1,277 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; + +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.ml.memory.action.conversation.GetTracesResponse; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; + +public class MLMemoryManagerTests { + + @Mock + Client client; + + @Mock + MLMemoryManager mlMemoryManager; + + @Mock + ActionListener createConversationResponseActionListener; + + @Mock + ActionListener createInteractionResponseActionListener; + + @Mock + ActionListener> interactionListActionListener; + + @Mock + ActionListener updateResponseActionListener; + + String conversationName; + String applicationType; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + mlMemoryManager = new MLMemoryManager(client); + conversationName = "new conversation"; + applicationType = "ml application"; + } + + @Test + public void testCreateConversation() { + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateConversationRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + al.onResponse(new CreateConversationResponse("conversation-id")); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager.createConversation(conversationName, applicationType, createConversationResponseActionListener); + + verify(client, times(1)) + .execute(eq(CreateConversationAction.INSTANCE), captor.capture(), eq(createConversationResponseActionListener)); + assertEquals(conversationName, captor.getValue().getName()); + assertEquals(applicationType, captor.getValue().getApplicationType()); + } + + @Test + public void testCreateConversationFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager.createConversation(conversationName, applicationType, createConversationResponseActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createConversationResponseActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testCreateInteraction() { + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateInteractionRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + al.onResponse(new CreateInteractionResponse("interaction-id")); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager + .createInteraction( + "conversationId", + "input", + "prompt", + "response", + "origin", + Collections.singletonMap("feedback", "thumbsup"), + "parent-id", + 1, + createInteractionResponseActionListener + ); + verify(client, times(1)) + .execute(eq(CreateInteractionAction.INSTANCE), captor.capture(), eq(createInteractionResponseActionListener)); + assertEquals("conversationId", captor.getValue().getConversationId()); + assertEquals("input", captor.getValue().getInput()); + assertEquals("prompt", captor.getValue().getPromptTemplate()); + assertEquals("response", captor.getValue().getResponse()); + assertEquals("origin", captor.getValue().getOrigin()); + assertEquals(Collections.singletonMap("feedback", "thumbsup"), captor.getValue().getAdditionalInfo()); + assertEquals("parent-id", captor.getValue().getParentIid()); + assertEquals("1", captor.getValue().getTraceNumber().toString()); + } + + @Test + public void testCreateInteractionFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager + .createInteraction( + "conversationId", + "input", + "prompt", + "response", + "origin", + Collections.singletonMap("feedback", "thumbsup"), + "parent-id", + 1, + createInteractionResponseActionListener + ); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createInteractionResponseActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testGetInteractions() { + List interactions = List + .of( + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ) + ); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + GetInteractionsResponse getInteractionsResponse = new GetInteractionsResponse(interactions, 4, false); + al.onResponse(getInteractionsResponse); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); + + verify(client, times(1)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture(), any()); + assertEquals("cid", captor.getValue().getConversationId()); + assertEquals(0, captor.getValue().getFrom()); + assertEquals(10, captor.getValue().getMaxResults()); + } + + @Test + public void testGetInteractionFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(interactionListActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testGetTraces() { + List traces = List + .of( + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 1 + ) + ); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetTracesRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + GetTracesResponse getTracesResponse = new GetTracesResponse(traces, 4, false); + al.onResponse(getTracesResponse); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager.getTraces("iid", interactionListActionListener); + + verify(client, times(1)).execute(eq(GetTracesAction.INSTANCE), captor.capture(), any()); + assertEquals("iid", captor.getValue().getInteractionId()); + assertEquals(0, captor.getValue().getFrom()); + assertEquals(10, captor.getValue().getMaxResults()); + } + + @Test + public void testGetTracesFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager.getTraces("cid", interactionListActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(interactionListActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testUpdateInteraction() { + Map updateContent = Map + .of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!"), INTERACTIONS_RESPONSE_FIELD, "response"); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + al.onResponse(updateResponse); + return null; + }).when(client).execute(any(), any(), any()); + + ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateInteractionRequest.class); + mlMemoryManager.updateInteraction("iid", updateContent, updateResponseActionListener); + verify(client, times(1)).execute(eq(UpdateInteractionAction.INSTANCE), captor.capture(), any()); + assertEquals("iid", captor.getValue().getInteractionId()); + assertEquals(1, captor.getValue().getUpdateContent().keySet().size()); + assertNotEquals(updateContent, captor.getValue().getUpdateContent()); + } + + @Test + public void testUpdateInteraction_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager + .updateInteraction( + "iid", + Map.of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!")), + updateResponseActionListener + ); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(updateResponseActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 89162beb0e..e986d7e3c1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -147,6 +147,8 @@ import org.opensearch.ml.memory.action.conversation.GetInteractionTransportAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsTransportAction; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesTransportAction; import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; import org.opensearch.ml.memory.action.conversation.SearchConversationsTransportAction; import org.opensearch.ml.memory.action.conversation.SearchInteractionsAction; @@ -193,6 +195,7 @@ import org.opensearch.ml.rest.RestMemoryGetConversationsAction; import org.opensearch.ml.rest.RestMemoryGetInteractionAction; import org.opensearch.ml.rest.RestMemoryGetInteractionsAction; +import org.opensearch.ml.rest.RestMemoryGetTracesAction; import org.opensearch.ml.rest.RestMemorySearchConversationsAction; import org.opensearch.ml.rest.RestMemorySearchInteractionsAction; import org.opensearch.ml.rest.RestMemoryUpdateConversationAction; @@ -329,7 +332,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class), new ActionHandler<>(UpdateConversationAction.INSTANCE, UpdateConversationTransportAction.class), - new ActionHandler<>(UpdateInteractionAction.INSTANCE, UpdateInteractionTransportAction.class) + new ActionHandler<>(UpdateInteractionAction.INSTANCE, UpdateInteractionTransportAction.class), + new ActionHandler<>(GetTracesAction.INSTANCE, GetTracesTransportAction.class) ); } @@ -587,6 +591,7 @@ public List getRestHandlers( RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction(); + RestMemoryGetTracesAction restMemoryGetTracesAction = new RestMemoryGetTracesAction(); return ImmutableList .of( restMLStatsAction, @@ -627,7 +632,8 @@ public List getRestHandlers( restGetConversationAction, restGetInteractionAction, restMemoryUpdateConversationAction, - restMemoryUpdateInteractionAction + restMemoryUpdateInteractionAction, + restMemoryGetTracesAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java new file mode 100644 index 0000000000..12c0815cc3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +public class RestMemoryGetTracesAction extends BaseRestHandler { + private final static String GET_TRACES_NAME = "conversational_get_traces"; + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.GET, ActionConstants.GET_TRACES_REST_PATH)); + } + + @Override + public String getName() { + return GET_TRACES_NAME; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetTracesRequest gtRequest = GetTracesRequest.fromRestRequest(request); + return channel -> client.execute(GetTracesAction.INSTANCE, gtRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java new file mode 100644 index 0000000000..67a91db6e8 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetTracesActionTests extends OpenSearchTestCase { + + public void testBasics() { + RestMemoryGetTracesAction action = new RestMemoryGetTracesAction(); + assert (action.getName().equals("conversational_get_traces")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new RestHandler.Route(RestRequest.Method.GET, ActionConstants.GET_TRACES_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetTracesAction action = new RestMemoryGetTracesAction(); + Map params = Map + .of( + ActionConstants.RESPONSE_INTERACTION_ID_FIELD, + "iid", + ActionConstants.REQUEST_MAX_RESULTS_FIELD, + "2", + ActionConstants.NEXT_TOKEN_FIELD, + "7" + ); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetTracesRequest.class); + verify(client, times(1)).execute(eq(GetTracesAction.INSTANCE), argCaptor.capture(), any()); + GetTracesRequest req = argCaptor.getValue(); + assert (req.getInteractionId().equals("iid")); + assert (req.getFrom() == 7); + assert (req.getMaxResults() == 2); + } + +}