Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Expose execute api for MLClient #1541

Merged
merged 1 commit into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
Expand Down Expand Up @@ -299,4 +302,23 @@ default ActionFuture<MLRegisterModelGroupResponse> registerModelGroup(MLRegister
* @param listener a listener to be notified of the result
*/
void registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput, ActionListener<MLRegisterModelGroupResponse> listener);

/**
* Execute an algorithm
* @param name algorithm function name
* @param input input
* @return the result future
*/
default ActionFuture<MLExecuteTaskResponse> execute(FunctionName name, Input input) {
PlainActionFuture<MLExecuteTaskResponse> actionFuture = PlainActionFuture.newFuture();
execute(name, input, actionFuture);
return actionFuture;
}

/**
* Execute an algorithm
* @param input an algorithm input
* @param listener a listener to be notified of the result
*/
void execute(FunctionName name, Input input, ActionListener<MLExecuteTaskResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
Expand All @@ -37,6 +38,9 @@
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand Down Expand Up @@ -182,6 +186,19 @@ public void registerModelGroup(
client.execute(MLRegisterModelGroupAction.INSTANCE, mlRegisterModelGroupRequest, listener);
}

/**
* Execute an algorithm
*
* @param name function name
* @param input an algorithm input
* @param listener a listener to be notified of the result
*/
@Override
public void execute(FunctionName name, Input input, ActionListener<MLExecuteTaskResponse> listener) {
MLExecuteTaskRequest mlExecuteTaskRequest = new MLExecuteTaskRequest(name, input);
client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, listener);
}

@Override
public void getTask(String taskId, ActionListener<MLTask> listener) {
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.opensearch.ml.common.input.Constants.KMEANS;
import static org.opensearch.ml.common.input.Constants.TRAIN;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand All @@ -32,7 +33,9 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
Expand All @@ -42,6 +45,7 @@
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
Expand Down Expand Up @@ -81,6 +85,9 @@ public class MachineLearningClientTest {
@Mock
MLRegisterModelGroupResponse registerModelGroupResponse;

@Mock
MLExecuteTaskResponse mlExecuteTaskResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -161,6 +168,11 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio
listener.onResponse(createConnectorResponse);
}

@Override
public void execute(FunctionName name, Input input, ActionListener<MLExecuteTaskResponse> listener) {
listener.onResponse(mlExecuteTaskResponse);
}

public void registerModelGroup(
MLRegisterModelGroupInput mlRegisterModelGroupInput,
ActionListener<MLRegisterModelGroupResponse> listener
Expand Down Expand Up @@ -354,4 +366,58 @@ public void createConnector() {

assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet());
}

@Test
public void executeMetricsCorrelation() {
List<float[]> inputData = new ArrayList<>(
Arrays
.asList(
new float[] {
0.89451003f,
4.2006273f,
0.3697659f,
2.2458954f,
-4.671612f,
-1.5076426f,
1.635445f,
-1.1394824f,
-0.7503817f,
0.98424894f,
-0.38896716f,
1.0328646f,
1.9543738f,
-0.5236269f,
0.14298044f,
3.2963762f,
8.1641035f,
5.717064f,
7.4869685f,
2.5987444f,
11.018798f,
9.151356f,
5.7354255f,
6.862203f,
3.0524514f,
4.431755f,
5.1481285f,
7.9548607f,
7.4519925f,
6.09533f,
7.634116f,
8.898271f,
3.898491f,
9.447067f,
8.197385f,
5.8284273f,
5.804283f,
7.089733f,
9.140584f }
)
);
Input metricsCorrelationInput = MetricsCorrelationInput.builder().inputData(inputData).build();
assertEquals(
mlExecuteTaskResponse,
machineLearningClient.execute(FunctionName.METRICS_CORRELATION, metricsCorrelationInput).actionGet()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.client;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand All @@ -22,6 +23,7 @@
import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -58,13 +60,19 @@
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensor;
import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors;
import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
Expand All @@ -73,6 +81,9 @@
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand Down Expand Up @@ -152,6 +163,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<MLRegisterModelGroupResponse> registerModelGroupResponseActionListener;

@Mock
ActionListener<MLExecuteTaskResponse> executeTaskResponseActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -676,6 +690,85 @@ public void createConnector() {

}

@Test
public void executeMetricsCorrelation() {
Output metricsCorrelationOutput;
List<MCorrModelTensors> outputs = new ArrayList<>();
MCorrModelTensor mCorrModelTensor = MCorrModelTensor
.builder()
.event_pattern(new float[] { 1.0f, 2.0f, 3.0f })
.event_window(new float[] { 4.0f, 5.0f, 6.0f })
.suspected_metrics(new long[] { 1, 2 })
.build();
List<MCorrModelTensor> mlModelTensors = Arrays.asList(mCorrModelTensor);
MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build();
outputs.add(modelTensors);
metricsCorrelationOutput = MetricsCorrelationOutput.builder().modelOutput(outputs).build();

doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
MLExecuteTaskResponse output = new MLExecuteTaskResponse(FunctionName.METRICS_CORRELATION, metricsCorrelationOutput);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());

ArgumentCaptor<MLExecuteTaskResponse> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskResponse.class);

List<float[]> inputData = new ArrayList<>(
Arrays
.asList(
new float[] {
0.89451003f,
4.2006273f,
0.3697659f,
2.2458954f,
-4.671612f,
-1.5076426f,
1.635445f,
-1.1394824f,
-0.7503817f,
0.98424894f,
-0.38896716f,
1.0328646f,
1.9543738f,
-0.5236269f,
0.14298044f,
3.2963762f,
8.1641035f,
5.717064f,
7.4869685f,
2.5987444f,
11.018798f,
9.151356f,
5.7354255f,
6.862203f,
3.0524514f,
4.431755f,
5.1481285f,
7.9548607f,
7.4519925f,
6.09533f,
7.634116f,
8.898271f,
3.898491f,
9.447067f,
8.197385f,
5.8284273f,
5.804283f,
7.089733f,
9.140584f }
)
);
Input metricsCorrelationInput = MetricsCorrelationInput.builder().inputData(inputData).build();

machineLearningNodeClient.execute(FunctionName.METRICS_CORRELATION, metricsCorrelationInput, executeTaskResponseActionListener);

verify(client).execute(eq(MLExecuteTaskAction.INSTANCE), isA(MLExecuteTaskRequest.class), any());
verify(executeTaskResponseActionListener).onResponse(argumentCaptor.capture());
assertEquals(FunctionName.METRICS_CORRELATION, argumentCaptor.getValue().getFunctionName());
assertTrue(argumentCaptor.getValue().getOutput() instanceof MetricsCorrelationOutput);
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
Loading