Skip to content

Commit

Permalink
Added create connector API for MLClient (#1437)
Browse files Browse the repository at this point in the history
* Adds create connector API for MLClient

Signed-off-by: Owais Kazi <[email protected]>

* Addressed PR comments

Signed-off-by: Owais Kazi <[email protected]>

* Addressed PR Comments

Signed-off-by: Owais Kazi <[email protected]>

---------

Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 authored Oct 12, 2023
1 parent 89f9b85 commit d265fc7
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
Expand Down Expand Up @@ -267,4 +269,17 @@ default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
* @param listener a listener to be notified of the result
*/
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);

/**
* Create connector for remote model
* @param mlCreateConnectorInput Create Connector Input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/extensibility/connectors/
* @return the result future
*/
default ActionFuture<MLCreateConnectorResponse> createConnector(MLCreateConnectorInput mlCreateConnectorInput) {
PlainActionFuture<MLCreateConnectorResponse> actionFuture = PlainActionFuture.newFuture();
createConnector(mlCreateConnectorInput, actionFuture);
return actionFuture;
}

void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
Expand Down Expand Up @@ -228,6 +232,12 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
}));
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput);
client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, listener);
}

private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(predictionResponse -> {
listener.onResponse(predictionResponse.getOutput());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -27,6 +28,8 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand Down Expand Up @@ -74,6 +77,9 @@ public class MachineLearningClientTest {
@Mock
MLDeployModelResponse deployModelResponse;

@Mock
MLCreateConnectorResponse createConnectorResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -158,6 +164,11 @@ public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterMode
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
listener.onResponse(deployModelResponse);
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
listener.onResponse(createConnectorResponse);
}
};
}

Expand Down Expand Up @@ -304,4 +315,26 @@ public void register() {
public void deploy() {
assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet());
}

@Test
public void createConnector() {
Map<String, String> params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7"));
Map<String, String> credentials = Map.ofEntries(Map.entry("key1", "key1"), Map.entry("key2", "key2"));

MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.builder()
.name("test")
.description("description")
.version("testModelVersion")
.protocol("testProtocol")
.parameters(params)
.credential(credentials)
.actions(null)
.backendRoles(null)
.addAllBackendRoles(false)
.access(AccessMode.from("private"))
.dryRun(false)
.build();

assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
Expand All @@ -42,6 +43,10 @@
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand Down Expand Up @@ -77,6 +82,8 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
Expand Down Expand Up @@ -121,10 +128,13 @@ public class MachineLearningNodeClientTest {
ActionListener<SearchResponse> searchTaskActionListener;

@Mock
ActionListener<MLRegisterModelResponse> RegisterModelActionListener;
ActionListener<MLRegisterModelResponse> registerModelActionListener;

@Mock
ActionListener<MLDeployModelResponse> DeployModelActionListener;
ActionListener<MLDeployModelResponse> deployModelActionListener;

@Mock
ActionListener<MLCreateConnectorResponse> createConnectorActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;
Expand Down Expand Up @@ -601,10 +611,10 @@ public void register() {
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();
machineLearningNodeClient.register(mlInput, RegisterModelActionListener);
machineLearningNodeClient.register(mlInput, registerModelActionListener);

verify(client).execute(eq(MLRegisterModelAction.INSTANCE), isA(MLRegisterModelRequest.class), any());
verify(RegisterModelActionListener).onResponse(argumentCaptor.capture());
verify(registerModelActionListener).onResponse(argumentCaptor.capture());
assertEquals(taskId, (argumentCaptor.getValue()).getTaskId());
assertEquals(status, (argumentCaptor.getValue()).getStatus());
}
Expand All @@ -615,7 +625,6 @@ public void deploy() {
String status = MLTaskState.CREATED.name();
MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL;
String modelId = "modelId";
FunctionName functionName = FunctionName.KMEANS;
doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status);
Expand All @@ -624,14 +633,55 @@ public void deploy() {
}).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any());

ArgumentCaptor<MLDeployModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class);
machineLearningNodeClient.deploy(modelId, DeployModelActionListener);
machineLearningNodeClient.deploy(modelId, deployModelActionListener);

verify(client).execute(eq(MLDeployModelAction.INSTANCE), isA(MLDeployModelRequest.class), any());
verify(DeployModelActionListener).onResponse(argumentCaptor.capture());
verify(deployModelActionListener).onResponse(argumentCaptor.capture());
assertEquals(taskId, (argumentCaptor.getValue()).getTaskId());
assertEquals(status, (argumentCaptor.getValue()).getStatus());
}

@Test
public void createConnector() {


String connectorId = "connectorId";

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(2);
MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLCreateConnectorAction.INSTANCE), any(), any());

ArgumentCaptor<MLCreateConnectorResponse> argumentCaptor = ArgumentCaptor.forClass(MLCreateConnectorResponse.class);

Map<String, String> params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7"));
Map<String, String> credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2"));
List<String> backendRoles = Arrays.asList("IT", "HR");

MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.builder()
.name("test")
.description("description")
.version("testModelVersion")
.protocol("testProtocol")
.parameters(params)
.credential(credentials)
.actions(null)
.backendRoles(backendRoles)
.addAllBackendRoles(false)
.access(AccessMode.from("private"))
.dryRun(false)
.build();

machineLearningNodeClient.createConnector(mlCreateConnectorInput, createConnectorActionListener);

verify(client).execute(eq(MLCreateConnectorAction.INSTANCE), isA(MLCreateConnectorRequest.class), any());
verify(createConnectorActionListener).onResponse(argumentCaptor.capture());
assertEquals(connectorId, (argumentCaptor.getValue()).getConnectorId());

}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down

0 comments on commit d265fc7

Please sign in to comment.