Skip to content

Commit

Permalink
Added create conversation API in MLClient
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Mar 18, 2024
1 parent ab1e054 commit 1e74bab
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 0 deletions.
1 change: 1 addition & 0 deletions client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ plugins {
dependencies {
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-memory")
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

/**
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
Expand Down Expand Up @@ -428,4 +429,22 @@ default ActionFuture<ToolMetadata> getTool(String toolName) {
*/
void getTool(String toolName, ActionListener<ToolMetadata> listener);

/**
* Create conversational memory for conversation
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
* @return the result future
*/
default ActionFuture<CreateConversationResponse> createConversation(String name) {
PlainActionFuture<CreateConversationResponse> actionFuture = PlainActionFuture.newFuture();
createConversation(name, actionFuture);
return actionFuture;
}

/**
* Create conversational memory for conversation
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
* @param listener action listener
*/
void createConversation(String name, ActionListener<CreateConversationResponse> listener);

}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -309,6 +312,11 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener));
}

public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
CreateConversationRequest createConversationRequest = new CreateConversationRequest(name);
client.execute(CreateConversationAction.INSTANCE, createConversationRequest, getCreateConversationResponseActionListener(listener));
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
Expand Down Expand Up @@ -379,6 +387,16 @@ private ActionListener<MLCreateConnectorResponse> getMlCreateConnectorResponseAc
return actionListener;
}

private ActionListener<CreateConversationResponse> getCreateConversationResponseActionListener(
ActionListener<CreateConversationResponse> listener
) {
ActionListener<CreateConversationResponse> actionListener = wrapActionListener(listener, response -> {
CreateConversationResponse conversationResponse = CreateConversationResponse.fromActionResponse(response);
return conversationResponse;
});
return actionListener;
}

private ActionListener<MLRegisterModelGroupResponse> getMlRegisterModelGroupResponseActionListener(
ActionListener<MLRegisterModelGroupResponse> listener
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

public class MachineLearningClientTest {

Expand Down Expand Up @@ -98,6 +99,9 @@ public class MachineLearningClientTest {
@Mock
MLRegisterAgentResponse registerAgentResponse;

@Mock
CreateConversationResponse createConversationResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -230,6 +234,11 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
listener.onResponse(createConversationResponse);
}
};
}

Expand Down Expand Up @@ -502,4 +511,9 @@ public void getTool() {
public void listTools() {
assertEquals(toolMetadata, machineLearningClient.listTools().actionGet().get(0));
}

@Test
public void createConversation() {
assertEquals(createConversationResponse, machineLearningClient.createConversation("Conversation for a RAG pipeline").actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -205,6 +208,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<ToolMetadata> getToolActionListener;

@Mock
ActionListener<CreateConversationResponse> createConversationResponseActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -950,6 +956,26 @@ public void listTools() {
assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription());
}

@Test
public void createConversation() {
String name = "Conversation for a RAG pipeline";
String conversationId = "conversationId";

doAnswer(invocation -> {
ActionListener<CreateConversationResponse> actionListener = invocation.getArgument(2);
CreateConversationResponse output = new CreateConversationResponse(conversationId);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any());

ArgumentCaptor<CreateConversationResponse> argumentCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class);
machineLearningNodeClient.createConversation(name, createConversationResponseActionListener);

verify(client).execute(eq(CreateConversationAction.INSTANCE), isA(CreateConversationRequest.class), any());
verify(createConversationResponseActionListener).onResponse(argumentCaptor.capture());
assertEquals(conversationId, argumentCaptor.getValue().getId());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@
*/
package org.opensearch.ml.memory.action.conversation;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;

import lombok.AllArgsConstructor;

Expand Down Expand Up @@ -67,4 +73,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
return builder;
}

public static CreateConversationResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLCreateConnectorResponse) {
return (CreateConversationResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new CreateConversationResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into CreateConversationResponse", e);
}

}

}

0 comments on commit 1e74bab

Please sign in to comment.