Skip to content

Commit

Permalink
Expose execute api for MLClient (#1540)
Browse files Browse the repository at this point in the history
* Expose execute api for MLClient

Signed-off-by: Jackie Han <[email protected]>

* unit test change

Signed-off-by: Jackie Han <[email protected]>

---------

Signed-off-by: Jackie Han <[email protected]>
  • Loading branch information
jackiehanyang authored Oct 23, 2023
1 parent e9e3834 commit 5d9324c
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
Expand Down Expand Up @@ -299,4 +302,23 @@ default ActionFuture<MLRegisterModelGroupResponse> registerModelGroup(MLRegister
* @param listener a listener to be notified of the result
*/
void registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput, ActionListener<MLRegisterModelGroupResponse> listener);

/**
* Execute an algorithm
* @param name algorithm function name
* @param input input
* @return the result future
*/
default ActionFuture<MLExecuteTaskResponse> execute(FunctionName name, Input input) {
PlainActionFuture<MLExecuteTaskResponse> actionFuture = PlainActionFuture.newFuture();
execute(name, input, actionFuture);
return actionFuture;
}

/**
* Execute an algorithm
* @param input an algorithm input
* @param listener a listener to be notified of the result
*/
void execute(FunctionName name, Input input, ActionListener<MLExecuteTaskResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
Expand All @@ -37,6 +38,9 @@
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand Down Expand Up @@ -182,6 +186,19 @@ public void registerModelGroup(
client.execute(MLRegisterModelGroupAction.INSTANCE, mlRegisterModelGroupRequest, listener);
}

/**
* Execute an algorithm
*
* @param name function name
* @param input an algorithm input
* @param listener a listener to be notified of the result
*/
@Override
public void execute(FunctionName name, Input input, ActionListener<MLExecuteTaskResponse> listener) {
MLExecuteTaskRequest mlExecuteTaskRequest = new MLExecuteTaskRequest(name, input);
client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, listener);
}

@Override
public void getTask(String taskId, ActionListener<MLTask> listener) {
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.opensearch.ml.common.input.Constants.KMEANS;
import static org.opensearch.ml.common.input.Constants.TRAIN;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand All @@ -32,7 +33,9 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
Expand All @@ -42,6 +45,7 @@
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
Expand Down Expand Up @@ -81,6 +85,9 @@ public class MachineLearningClientTest {
@Mock
MLRegisterModelGroupResponse registerModelGroupResponse;

@Mock
MLExecuteTaskResponse mlExecuteTaskResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -161,6 +168,11 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio
listener.onResponse(createConnectorResponse);
}

@Override
public void execute(FunctionName name, Input input, ActionListener<MLExecuteTaskResponse> listener) {
listener.onResponse(mlExecuteTaskResponse);
}

public void registerModelGroup(
MLRegisterModelGroupInput mlRegisterModelGroupInput,
ActionListener<MLRegisterModelGroupResponse> listener
Expand Down Expand Up @@ -354,4 +366,58 @@ public void createConnector() {

assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet());
}

@Test
public void executeMetricsCorrelation() {
List<float[]> inputData = new ArrayList<>(
Arrays
.asList(
new float[] {
0.89451003f,
4.2006273f,
0.3697659f,
2.2458954f,
-4.671612f,
-1.5076426f,
1.635445f,
-1.1394824f,
-0.7503817f,
0.98424894f,
-0.38896716f,
1.0328646f,
1.9543738f,
-0.5236269f,
0.14298044f,
3.2963762f,
8.1641035f,
5.717064f,
7.4869685f,
2.5987444f,
11.018798f,
9.151356f,
5.7354255f,
6.862203f,
3.0524514f,
4.431755f,
5.1481285f,
7.9548607f,
7.4519925f,
6.09533f,
7.634116f,
8.898271f,
3.898491f,
9.447067f,
8.197385f,
5.8284273f,
5.804283f,
7.089733f,
9.140584f }
)
);
Input metricsCorrelationInput = MetricsCorrelationInput.builder().inputData(inputData).build();
assertEquals(
mlExecuteTaskResponse,
machineLearningClient.execute(FunctionName.METRICS_CORRELATION, metricsCorrelationInput).actionGet()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.client;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand All @@ -22,6 +23,7 @@
import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -58,13 +60,19 @@
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensor;
import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors;
import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
Expand All @@ -73,6 +81,9 @@
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand Down Expand Up @@ -152,6 +163,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<MLRegisterModelGroupResponse> registerModelGroupResponseActionListener;

@Mock
ActionListener<MLExecuteTaskResponse> executeTaskResponseActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -676,6 +690,85 @@ public void createConnector() {

}

@Test
public void executeMetricsCorrelation() {
Output metricsCorrelationOutput;
List<MCorrModelTensors> outputs = new ArrayList<>();
MCorrModelTensor mCorrModelTensor = MCorrModelTensor
.builder()
.event_pattern(new float[] { 1.0f, 2.0f, 3.0f })
.event_window(new float[] { 4.0f, 5.0f, 6.0f })
.suspected_metrics(new long[] { 1, 2 })
.build();
List<MCorrModelTensor> mlModelTensors = Arrays.asList(mCorrModelTensor);
MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build();
outputs.add(modelTensors);
metricsCorrelationOutput = MetricsCorrelationOutput.builder().modelOutput(outputs).build();

doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
MLExecuteTaskResponse output = new MLExecuteTaskResponse(FunctionName.METRICS_CORRELATION, metricsCorrelationOutput);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());

ArgumentCaptor<MLExecuteTaskResponse> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskResponse.class);

List<float[]> inputData = new ArrayList<>(
Arrays
.asList(
new float[] {
0.89451003f,
4.2006273f,
0.3697659f,
2.2458954f,
-4.671612f,
-1.5076426f,
1.635445f,
-1.1394824f,
-0.7503817f,
0.98424894f,
-0.38896716f,
1.0328646f,
1.9543738f,
-0.5236269f,
0.14298044f,
3.2963762f,
8.1641035f,
5.717064f,
7.4869685f,
2.5987444f,
11.018798f,
9.151356f,
5.7354255f,
6.862203f,
3.0524514f,
4.431755f,
5.1481285f,
7.9548607f,
7.4519925f,
6.09533f,
7.634116f,
8.898271f,
3.898491f,
9.447067f,
8.197385f,
5.8284273f,
5.804283f,
7.089733f,
9.140584f }
)
);
Input metricsCorrelationInput = MetricsCorrelationInput.builder().inputData(inputData).build();

machineLearningNodeClient.execute(FunctionName.METRICS_CORRELATION, metricsCorrelationInput, executeTaskResponseActionListener);

verify(client).execute(eq(MLExecuteTaskAction.INSTANCE), isA(MLExecuteTaskRequest.class), any());
verify(executeTaskResponseActionListener).onResponse(argumentCaptor.capture());
assertEquals(FunctionName.METRICS_CORRELATION, argumentCaptor.getValue().getFunctionName());
assertTrue(argumentCaptor.getValue().getOutput() instanceof MetricsCorrelationOutput);
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down

0 comments on commit 5d9324c

Please sign in to comment.