From 1a2e623464b111a808a66a1ab20a4650b038a980 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Fri, 2 Aug 2024 17:51:26 -0500 Subject: [PATCH] add batch predict to task Signed-off-by: Bhavana Ramaram --- .../org/opensearch/ml/common/CommonValue.java | 6 +- .../java/org/opensearch/ml/common/MLTask.java | 47 ++++- .../org/opensearch/ml/common/MLTaskType.java | 1 + .../ml/common/connector/ConnectorAction.java | 4 +- .../action/tasks/GetTaskTransportAction.java | 166 +++++++++++++++++- .../ml/task/MLPredictTaskRunner.java | 68 ++++++- .../tasks/GetTaskTransportActionTests.java | 36 +++- .../ml/task/MLPredictTaskRunnerTests.java | 85 +++++++++ 8 files changed, 402 insertions(+), 11 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index b4a1a665a5..70589d7cdd 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -65,7 +65,7 @@ public class CommonValue { public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11; public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; - public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; + public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 3; public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3; public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 3; @@ -363,6 +363,10 @@ public class CommonValue { + "\" : {\"type\" : \"boolean\"}, \n" + USER_FIELD_MAPPING + " }\n" + + "}" + + MLTask.TRANSFORM_JOB_FIELD + + "\" : {\"type\": \"flat_object\"}\n" + + " }\n" + "}"; public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n" diff --git a/common/src/main/java/org/opensearch/ml/common/MLTask.java b/common/src/main/java/org/opensearch/ml/common/MLTask.java index 229bba5771..813c72f2e4 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTask.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTask.java @@ -9,6 +9,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -17,15 +18,22 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.dataset.MLInputDataType; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.USER; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; +import static org.opensearch.ml.common.utils.StringUtils.gson; @Getter @EqualsAndHashCode @@ -44,6 +52,8 @@ public class MLTask implements ToXContentObject, Writeable { public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; public static final String ERROR_FIELD = "error"; public static final String IS_ASYNC_TASK_FIELD = "is_async"; + public static final String TRANSFORM_JOB_FIELD = "transform_job"; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB = CommonValue.VERSION_2_16_0; @Setter private String taskId; @@ -65,6 +75,8 @@ public class MLTask implements ToXContentObject, Writeable { private String error; private User user; // TODO: support document level access control later private boolean async; + @Setter + private Map transformJob; @Builder(toBuilder = true) public MLTask( @@ -81,7 +93,8 @@ public MLTask( Instant lastUpdateTime, String error, User user, - boolean async + boolean async, + Map transformJob ) { this.taskId = taskId; this.modelId = modelId; @@ -97,9 +110,11 @@ public MLTask( this.error = error; this.user = user; this.async = async; + this.transformJob = transformJob; } public MLTask(StreamInput input) throws IOException { + Version streamInputVersion = input.getVersion(); this.taskId = input.readOptionalString(); this.modelId = input.readOptionalString(); this.taskType = input.readEnum(MLTaskType.class); @@ -122,10 +137,17 @@ public MLTask(StreamInput input) throws IOException { this.user = null; } this.async = input.readBoolean(); + if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) { + if (input.readBoolean()) { + String mapStr = input.readString(); + this.transformJob = gson.fromJson(mapStr, Map.class); + } + } } @Override public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeOptionalString(taskId); out.writeOptionalString(modelId); out.writeEnum(taskType); @@ -149,6 +171,21 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeBoolean(async); + if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) { + if (transformJob != null) { + out.writeBoolean(true); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + out.writeString(gson.toJson(transformJob)); + return null; + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } + } else { + out.writeBoolean(false); + } + } } @Override @@ -194,6 +231,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params builder.field(USER, user); } builder.field(IS_ASYNC_TASK_FIELD, async); + if (transformJob != null) { + builder.field(TRANSFORM_JOB_FIELD, transformJob); + } return builder.endObject(); } @@ -217,6 +257,7 @@ public static MLTask parse(XContentParser parser) throws IOException { String error = null; User user = null; boolean async = false; + Map transformJob = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -274,6 +315,9 @@ public static MLTask parse(XContentParser parser) throws IOException { case IS_ASYNC_TASK_FIELD: async = parser.booleanValue(); break; + case TRANSFORM_JOB_FIELD: + transformJob = parser.map(); + break; default: parser.skipChildren(); break; @@ -294,6 +338,7 @@ public static MLTask parse(XContentParser parser) throws IOException { .error(error) .user(user) .async(async) + .transformJob(transformJob) .build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java index db2f67f369..fee55bc712 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java @@ -8,6 +8,7 @@ public enum MLTaskType { TRAINING, PREDICTION, + BATCH_PREDICTION, TRAINING_AND_PREDICTION, EXECUTION, @Deprecated diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 9be290d126..4d4f442d3e 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -187,7 +187,9 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { public enum ActionType { PREDICT, EXECUTE, - BATCH_PREDICT; + BATCH_PREDICT, + CANCEL_BATCH, + BATCH_STATUS; public static ActionType from(String value) { try { diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 88c05f71c1..c17009afdd 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -6,15 +6,27 @@ package org.opensearch.ml.action.tasks; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTask.TRANSFORM_JOB_FIELD; +import static org.opensearch.ml.common.MLTaskState.COMPLETED; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_STATUS; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.*; +import static org.opensearch.ml.utils.MLExceptionUtils.logException; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import java.util.HashMap; +import java.util.Map; + import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -22,11 +34,28 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.task.MLTaskGetAction; import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.script.ScriptService; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -38,16 +67,38 @@ public class GetTaskTransportAction extends HandledTransportAction actionListener) { + + Map transformJob = mlTask.getTransformJob(); + Map parameters = new HashMap<>(); + for (Map.Entry entry : transformJob.entrySet()) { + if (entry.getValue() instanceof String) { + parameters.put(entry.getKey(), (String) entry.getValue()); + } else { + log.debug("Value for key " + entry.getKey() + " is not a String"); + } + } + + if (parameters.containsKey("TransformJobArn") && parameters.get("TransformJobArn") != null) { + String jobArn = parameters.get("TransformJobArn"); + String transformJobName = jobArn.substring(jobArn.lastIndexOf("/") + 1); + parameters.put("TransformJobName", transformJobName); + parameters.remove("TransformJobArn"); + } + + RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build(); + String modelId = mlTask.getModelId(); + + try { + mlModelManager.getModel(modelId, null, null, ActionListener.wrap(model -> { + if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { + ActionListener listener = ActionListener.wrap(connector -> { + connector.decrypt(BATCH_STATUS.name(), (credential) -> encryptor.decrypt(credential)); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader + .initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor.executeAction(BATCH_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> { + processTaskResponse(mlTask, taskId, taskResponse, transformJob, actionListener); + }, e -> { actionListener.onFailure(e); })); + }, e -> { + log.error("Failed to get connector " + model.getConnectorId(), e); + actionListener.onFailure(e); + }); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper + .getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore)); + } + } else { + actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + } + }, e -> { + log.error("Failed to retrieve the ML model with the given ID", e); + actionListener.onFailure(e); + })); + } catch (Exception e) { + // fetch the connector + log.error("Unable to fetch status for ml task ", e); + } + } + + private void processTaskResponse( + MLTask mlTask, + String taskId, + MLTaskResponse taskResponse, + Map transformJob, + ActionListener actionListener + ) { + try { + ModelTensorOutput tensorOutput = (ModelTensorOutput) taskResponse.getOutput(); + if (tensorOutput != null && tensorOutput.getMlModelOutputs() != null && !tensorOutput.getMlModelOutputs().isEmpty()) { + ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0); + if (modelOutput.getMlModelTensors() != null && !modelOutput.getMlModelTensors().isEmpty()) { + Map transformJobStatus = (Map) modelOutput.getMlModelTensors().get(0).getDataAsMap(); + if (transformJobStatus != null) { + transformJob.putAll(transformJobStatus); + Map updatedTask = new HashMap<>(); + updatedTask.put(TRANSFORM_JOB_FIELD, transformJob); + + if ((transformJob.containsKey("status") && transformJob.get("status").equals("completed")) + || (transformJob.containsKey("TransformJobStatus") + && transformJob.get("TransformJobStatus").equals("Completed"))) { + updatedTask.put(STATE_FIELD, COMPLETED); + mlTask.setState(COMPLETED); + } + mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> { + actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build()); + }, e -> { + logException("Failed to update task for batch predict model", e, log); + actionListener.onFailure(e); + })); + } else { + log.debug("Transform job status is null."); + actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job")); + } + } else { + log.debug("ML Model Tensors are null or empty."); + actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job")); + } + } else { + log.debug("ML Model Outputs are null or empty."); + actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job")); + } + } catch (Exception e) { + log.error("Unable to fetch status for ml task ", e); + } } } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 72c43bd58f..35150db8d9 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -14,9 +14,12 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.PREDICT_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; +import static org.opensearch.ml.utils.MLExceptionUtils.logException; import java.time.Instant; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import java.util.UUID; import org.opensearch.OpenSearchException; @@ -55,6 +58,7 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; @@ -228,11 +232,18 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener dataFrameActionListener = ActionListener.wrap(dataSet -> { @@ -336,12 +346,60 @@ private void runPredict( if (mlInput.getAlgorithm() == FunctionName.REMOTE) { long startTime = System.nanoTime(); ActionListener trackPredictDurationListener = ActionListener.wrap(output -> { + if (output.getOutput() instanceof ModelTensorOutput) { validateOutputSchema(modelId, (ModelTensorOutput) output.getOutput()); } - handleAsyncMLTaskComplete(mlTask); - mlModelManager.trackPredictDuration(modelId, startTime); - internalListener.onResponse(output); + if (mlTask.getTaskType().equals(MLTaskType.BATCH_PREDICTION)) { + Map transformJob = new HashMap<>(); + ModelTensorOutput tensorOutput = (ModelTensorOutput) output.getOutput(); + if (tensorOutput != null + && tensorOutput.getMlModelOutputs() != null + && !tensorOutput.getMlModelOutputs().isEmpty()) { + ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0); + if (modelOutput.getMlModelTensors() != null && !modelOutput.getMlModelTensors().isEmpty()) { + Map dataAsMap = (Map) modelOutput + .getMlModelTensors() + .get(0) + .getDataAsMap(); + if (dataAsMap != null + && (dataAsMap.containsKey("TransformJobArn") || dataAsMap.containsKey("id"))) { + transformJob.putAll(dataAsMap); + mlTask.setTransformJob(transformJob); + mlTask.setTaskId(null); + mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { + String taskId = response.getId(); + mlTask.setTaskId(taskId); + MLPredictionOutput outputBuilder = MLPredictionOutput + .builder() + .taskId(taskId) + .status(MLTaskState.CREATED.name()) + .build(); + + MLTaskResponse predictOutput = MLTaskResponse.builder().output(outputBuilder).build(); + internalListener.onResponse(predictOutput); + }, e -> { + logException("Failed to create task for batch predict model", e, log); + internalListener.onFailure(e); + })); + } else { + log.debug("Batch transform job output from remote model did not return the job ID"); + internalListener + .onFailure(new ResourceNotFoundException("Unable to create batch transform job")); + } + } else { + log.debug("ML Model Tensors are null or empty."); + internalListener.onFailure(new ResourceNotFoundException("Unable to create batch transform job")); + } + } else { + log.debug("ML Model Outputs are null or empty."); + internalListener.onFailure(new ResourceNotFoundException("Unable to create batch transform job")); + } + } else { + handleAsyncMLTaskComplete(mlTask); + mlModelManager.trackPredictDuration(modelId, startTime); + internalListener.onResponse(output); + } }, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName)); predictor.asyncPredict(mlInput, trackPredictDurationListener); } else { diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index 83da0f8273..3e59258436 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -22,6 +22,7 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -29,6 +30,11 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.script.ScriptService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -46,11 +52,26 @@ public class GetTaskTransportActionTests extends OpenSearchTestCase { @Mock TransportService transportService; + @Mock + private ClusterService clusterService; + @Mock + private ScriptService scriptService; + @Mock ActionFilters actionFilters; + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Mock + private EncryptorImpl encryptor; @Mock ActionListener actionListener; + @Mock + private MLModelManager mlModelManager; + + @Mock + private MLTaskManager mlTaskManager; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -64,7 +85,20 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlTaskGetRequest = MLTaskGetRequest.builder().taskId("test_id").build(); - getTaskTransportAction = spy(new GetTaskTransportAction(transportService, actionFilters, client, xContentRegistry)); + getTaskTransportAction = spy( + new GetTaskTransportAction( + transportService, + actionFilters, + client, + xContentRegistry, + clusterService, + scriptService, + connectorAccessControlHelper, + encryptor, + mlTaskManager, + mlModelManager + ) + ); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index cbde703543..e09610140d 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -25,10 +25,12 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; import org.opensearch.Version; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -49,11 +51,13 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.PredictMode; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; import org.opensearch.ml.common.output.MLPredictionOutput; @@ -412,6 +416,87 @@ public void testValidateModelTensorOutputSuccess() { taskRunner.validateOutputSchema("testId", modelTensorOutput); } + public void testValidateBatchPredictionSuccess() throws IOException { + setupMocks(true, false, false, false); + RemoteInferenceInputDataSet remoteInputDataSet = RemoteInferenceInputDataSet.builder().predictMode(PredictMode.BATCH).build(); + MLPredictionTaskRequest remoteInputRequest = MLPredictionTaskRequest + .builder() + .modelId("test_model") + .mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInputDataSet).build()) + .build(); + Predictable predictor = mock(Predictable.class); + when(predictor.isModelReady()).thenReturn(true); + ModelTensor modelTensor = ModelTensor + .builder() + .name("response") + .dataAsMap(Map.of("TransformJobArn", "arn:aws:sagemaker:us-east-1:802041417063:transform-job/batch-transform-01")) + .build(); + Map modelInterface = Map + .of( + "output", + "{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"array\"}}}" + ); + ModelTensorOutput modelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build())) + .build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build()); + return null; + }).when(predictor).asyncPredict(any(), any()); + + IndexResponse indexResponse = mock(IndexResponse.class); + when(indexResponse.getId()).thenReturn("mockTaskId"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class)); + + when(mlModelManager.getModelInterface(any())).thenReturn(modelInterface); + + when(mlModelManager.getPredictor(anyString())).thenReturn(predictor); + when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.REMOTE), eq(true))).thenReturn(new String[] { "node1" }); + taskRunner.dispatchTask(FunctionName.REMOTE, remoteInputRequest, transportService, listener); + verify(client, never()).get(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); + verify(listener).onResponse(argumentCaptor.capture()); + } + + public void testValidateBatchPredictionFailure() throws IOException { + setupMocks(true, false, false, false); + RemoteInferenceInputDataSet remoteInputDataSet = RemoteInferenceInputDataSet.builder().predictMode(PredictMode.BATCH).build(); + MLPredictionTaskRequest remoteInputRequest = MLPredictionTaskRequest + .builder() + .modelId("test_model") + .mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInputDataSet).build()) + .build(); + Predictable predictor = mock(Predictable.class); + when(predictor.isModelReady()).thenReturn(true); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener + .onResponse(MLTaskResponse.builder().output(ModelTensorOutput.builder().mlModelOutputs(List.of()).build()).build()); + return null; + }).when(predictor).asyncPredict(any(), any()); + + IndexResponse indexResponse = mock(IndexResponse.class); + when(indexResponse.getId()).thenReturn("mockTaskId"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class)); + + when(mlModelManager.getPredictor(anyString())).thenReturn(predictor); + when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.REMOTE), eq(true))).thenReturn(new String[] { "node1" }); + taskRunner.dispatchTask(FunctionName.REMOTE, remoteInputRequest, transportService, listener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("Unable to create batch transform job", argumentCaptor.getValue().getMessage()); + } + public void testValidateModelTensorOutputFailed() { exceptionRule.expect(OpenSearchStatusException.class); ModelTensor modelTensor = ModelTensor