From f78c89c878df6f4d39ca8723b63d74810ed1b3f9 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Fri, 19 Jul 2024 21:17:32 -0500 Subject: [PATCH] add task to batch predict request Signed-off-by: Bhavana Ramaram --- .../org/opensearch/ml/common/CommonValue.java | 7 +- .../java/org/opensearch/ml/common/MLTask.java | 47 ++++- .../org/opensearch/ml/common/MLTaskType.java | 1 + .../ml/common/connector/ConnectorAction.java | 4 +- .../ml/common/transport/MLTaskResponse.java | 34 +++- .../action/tasks/GetTaskTransportAction.java | 172 +++++++++++++++++- .../ml/task/MLPredictTaskRunner.java | 63 ++++++- .../tasks/GetTaskTransportActionTests.java | 26 ++- 8 files changed, 339 insertions(+), 15 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 422467241b..1dbc794f8b 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -62,7 +62,7 @@ public class CommonValue { public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11; public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; - public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; + public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 3; public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3; public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2; @@ -359,6 +359,10 @@ public class CommonValue { + "\" : {\"type\" : \"boolean\"}, \n" + USER_FIELD_MAPPING + " }\n" + + "}" + + MLTask.TRANSFORM_JOB_FIELD + + "\" : {\"type\": \"flat_object\"}\n" + + " }\n" + "}"; public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n" @@ -537,4 +541,5 @@ public class CommonValue { public static final Version VERSION_2_12_0 = Version.fromString("2.12.0"); public 
static final Version VERSION_2_13_0 = Version.fromString("2.13.0"); public static final Version VERSION_2_14_0 = Version.fromString("2.14.0"); + public static final Version VERSION_2_16_0 = Version.fromString("2.16.0"); } diff --git a/common/src/main/java/org/opensearch/ml/common/MLTask.java b/common/src/main/java/org/opensearch/ml/common/MLTask.java index 229bba5771..813c72f2e4 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTask.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTask.java @@ -9,6 +9,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -17,15 +18,22 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.dataset.MLInputDataType; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.USER; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; +import static org.opensearch.ml.common.utils.StringUtils.gson; @Getter @EqualsAndHashCode @@ -44,6 +52,8 @@ public class MLTask implements ToXContentObject, Writeable { public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; public static final String ERROR_FIELD = "error"; public static final String IS_ASYNC_TASK_FIELD = "is_async"; + public static final String TRANSFORM_JOB_FIELD = "transform_job"; + 
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB = CommonValue.VERSION_2_16_0; @Setter private String taskId; @@ -65,6 +75,8 @@ public class MLTask implements ToXContentObject, Writeable { private String error; private User user; // TODO: support document level access control later private boolean async; + @Setter + private Map transformJob; @Builder(toBuilder = true) public MLTask( @@ -81,7 +93,8 @@ public MLTask( Instant lastUpdateTime, String error, User user, - boolean async + boolean async, + Map transformJob ) { this.taskId = taskId; this.modelId = modelId; @@ -97,9 +110,11 @@ public MLTask( this.error = error; this.user = user; this.async = async; + this.transformJob = transformJob; } public MLTask(StreamInput input) throws IOException { + Version streamInputVersion = input.getVersion(); this.taskId = input.readOptionalString(); this.modelId = input.readOptionalString(); this.taskType = input.readEnum(MLTaskType.class); @@ -122,10 +137,17 @@ public MLTask(StreamInput input) throws IOException { this.user = null; } this.async = input.readBoolean(); + if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) { + if (input.readBoolean()) { + String mapStr = input.readString(); + this.transformJob = gson.fromJson(mapStr, Map.class); + } + } } @Override public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeOptionalString(taskId); out.writeOptionalString(modelId); out.writeEnum(taskType); @@ -149,6 +171,21 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeBoolean(async); + if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) { + if (transformJob != null) { + out.writeBoolean(true); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + out.writeString(gson.toJson(transformJob)); + return null; + }); + } catch 
(PrivilegedActionException e) { + throw new RuntimeException(e); + } + } else { + out.writeBoolean(false); + } + } } @Override @@ -194,6 +231,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params builder.field(USER, user); } builder.field(IS_ASYNC_TASK_FIELD, async); + if (transformJob != null) { + builder.field(TRANSFORM_JOB_FIELD, transformJob); + } return builder.endObject(); } @@ -217,6 +257,7 @@ public static MLTask parse(XContentParser parser) throws IOException { String error = null; User user = null; boolean async = false; + Map transformJob = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -274,6 +315,9 @@ public static MLTask parse(XContentParser parser) throws IOException { case IS_ASYNC_TASK_FIELD: async = parser.booleanValue(); break; + case TRANSFORM_JOB_FIELD: + transformJob = parser.map(); + break; default: parser.skipChildren(); break; @@ -294,6 +338,7 @@ public static MLTask parse(XContentParser parser) throws IOException { .error(error) .user(user) .async(async) + .transformJob(transformJob) .build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java index db2f67f369..fee55bc712 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java @@ -8,6 +8,7 @@ public enum MLTaskType { TRAINING, PREDICTION, + BATCH_PREDICTION, TRAINING_AND_PREDICTION, EXECUTION, @Deprecated diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 221a0c5758..9832dceb72 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ 
-184,6 +184,8 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { public enum ActionType { PREDICT, EXECUTE, - BATCH + BATCH, + CANCEL_BATCH, + BATCH_STATUS } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/MLTaskResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/MLTaskResponse.java index 4358538741..e7fc414583 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/MLTaskResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/MLTaskResponse.java @@ -28,6 +28,9 @@ public class MLTaskResponse extends ActionResponse implements ToXContentObject { MLOutput output; + private String taskId; + private String status; + @Builder public MLTaskResponse(MLOutput output) { this.output = output; @@ -38,9 +41,26 @@ public MLTaskResponse(StreamInput in) throws IOException { - output = MLOutput.fromStream(in); + // Must mirror writeTo(): a presence flag precedes the optional output, then the optional task id and status follow. + if (in.readBoolean()) { + output = MLOutput.fromStream(in); + } + taskId = in.readOptionalString(); + status = in.readOptionalString(); } + public MLTaskResponse(String taskId, String status) { + this.taskId = taskId; + this.status = status; + } + @Override public void writeTo(StreamOutput out) throws IOException { - output.writeTo(out); + if (this.output != null) { + out.writeBoolean(true); + output.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(this.taskId); + out.writeOptionalString(this.status); } public static MLTaskResponse fromActionResponse(ActionResponse actionResponse) { @@ -60,7 +80,20 @@ public static MLTaskResponse fromActionResponse(ActionResponse actionResponse) { } @Override - public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { - return output.toXContent(builder, params); + public XContentBuilder toXContent(final XContentBuilder xContentBuilder, final Params params) throws IOException { + if (output != null) { + return output.toXContent(xContentBuilder, params); + } else { + XContentBuilder builder = xContentBuilder.startObject(); + if (taskId != null) { + builder.field("task_id", this.taskId); + } +
if (status != null) { + builder.field("status", this.status); + } + + builder.endObject(); + return builder; + } } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 88c05f71c1..73e7d39b2b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -6,50 +6,120 @@ package org.opensearch.ml.action.tasks; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTask.TRANSFORM_JOB_FIELD; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_STATUS; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.*; +import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; +import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchParseException; import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import 
org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.collect.Tuple; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.TokenBucket; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.PredictMode; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLGuard; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.task.MLTaskGetAction; import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.Predictable; +import org.opensearch.ml.engine.algorithms.remote.ExecutionContext; +import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; +import 
org.opensearch.ml.engine.annotation.ConnectorExecutor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.script.ScriptService; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import org.opensearch.action.support.ThreadedActionListener; + import lombok.extern.log4j.Log4j2; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + @Log4j2 public class GetTaskTransportAction extends HandledTransportAction { Client client; NamedXContentRegistry xContentRegistry; + ClusterService clusterService; + ScriptService scriptService; + + ConnectorAccessControlHelper connectorAccessControlHelper; + EncryptorImpl encryptor; + MLModelManager mlModelManager; + + MLTaskManager mlTaskManager; + MLModelCacheHelper modelCacheHelper; + @Inject public GetTaskTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, - NamedXContentRegistry xContentRegistry + NamedXContentRegistry xContentRegistry, + ClusterService clusterService, + ScriptService scriptService, + ConnectorAccessControlHelper connectorAccessControlHelper, + EncryptorImpl encryptor, + MLTaskManager mlTaskManager, + MLModelManager mlModelManager ) { super(MLTaskGetAction.NAME, transportService, actionFilters, MLTaskGetRequest::new); this.client = client; this.xContentRegistry = xContentRegistry; + this.clusterService = clusterService; + this.scriptService = scriptService; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.encryptor = encryptor; + this.mlTaskManager = mlTaskManager; + this.mlModelManager = mlModelManager; } - @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLTaskGetRequest 
mlTaskGetRequest = MLTaskGetRequest.fromActionRequest(request); @@ -64,7 +134,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + + Map transformJob = mlTask.getTransformJob(); + + Map parameters = new HashMap<>(); + for (Map.Entry entry : transformJob.entrySet()) { + if (entry.getValue() instanceof String) { + parameters.put(entry.getKey(), (String) entry.getValue()); + } else { + log.debug("Value for key " + entry.getKey() + " is not a String"); + } + } + + if (parameters.containsKey("TransformJobArn") && parameters.get("TransformJobArn") != null) { + String jobArn = parameters.get("TransformJobArn"); + String transformJobName = jobArn.substring(jobArn.lastIndexOf("/") + 1); + parameters.put("TransformJobName", transformJobName); + parameters.remove("TransformJobArn"); + } + RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, PredictMode.BATCH); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build(); + String modelId = mlTask.getModelId(); + + try { + mlModelManager.getModel( + modelId, null, null, ActionListener.wrap(model -> { + if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { + ActionListener listener = ActionListener.wrap(connector -> { + connector.decrypt(BATCH_STATUS.name(), (credential) -> encryptor.decrypt(credential)); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader + .initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor + .executeAction(BATCH_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> { + processTaskResponse(mlTask, taskResponse, transformJob, actionListener); + }, e -> { + actionListener.onFailure(e); + })); + }, e 
-> { + log.error("Failed to get connector " + model.getConnectorId(), e); + actionListener.onFailure(e); + }); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore)); + } + } else { + actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + } + }, e -> { + log.error("Failed to retrieve the ML model with the given ID", e); + actionListener.onFailure(e); + })); + } catch (Exception e) { + // A synchronous failure must still complete the listener, otherwise the caller waits forever. + log.error("Unable to fetch status for ml task ", e); + actionListener.onFailure(e); + } + } + + /** + * Copies the data map of the first tensor in the connector response into {@code transformJob} and answers the listener with the task. + * The listener is failed when the response carries no such data or when processing throws. + */ + private void processTaskResponse(MLTask mlTask, MLTaskResponse taskResponse, Map transformJob, ActionListener actionListener) { + try { + ModelTensorOutput tensorOutput = (ModelTensorOutput) taskResponse.getOutput(); + if (tensorOutput != null && tensorOutput.getMlModelOutputs() != null && !tensorOutput.getMlModelOutputs().isEmpty()) { + ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0); + if (modelOutput.getMlModelTensors() != null && !modelOutput.getMlModelTensors().isEmpty()) { + Map transformJobStatus = (Map) modelOutput.getMlModelTensors().get(0).getDataAsMap(); + if (transformJobStatus != null) { + transformJob.putAll(transformJobStatus); + // TODO: Delete task from index based on status + actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build()); + } else { + log.debug("Transform job status is null."); + actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job")); + } + } else { + log.debug("ML Model Tensors are null or empty."); + actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job")); + } + } else { + log.debug("ML Model Outputs are null or empty."); + actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch 
status of the transform job")); + } + } catch (Exception e) { + // Report the failure instead of only logging it, so the request does not hang. + log.error("Unable to fetch status for ml task ", e); + actionListener.onFailure(e); + } } } + diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index b341f4c9f5..28f642e77d 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -8,17 +8,30 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; +import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; +import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; +import static org.opensearch.ml.common.MLTask.TRANSFORM_JOB_FIELD; +import static org.opensearch.ml.common.MLTaskState.COMPLETED; +import static org.opensearch.ml.common.MLTaskState.CREATED; +import static org.opensearch.ml.common.MLTaskState.FAILED; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; import static org.opensearch.ml.permission.AccessController.checkUserPermissions; import static org.opensearch.ml.permission.AccessController.getUserContext; import static org.opensearch.ml.plugin.MachineLearningPlugin.PREDICT_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; +import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; +import static org.opensearch.ml.utils.MLExceptionUtils.logException; import java.time.Instant; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import java.util.UUID; +import org.apache.commons.lang3.RandomStringUtils; +import
org.apache.commons.lang3.StringUtils; import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; @@ -47,12 +60,16 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.PredictMode; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; @@ -66,6 +83,7 @@ import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; +import org.opensearch.ml.utils.MLExceptionUtils; import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportResponseHandler; @@ -226,11 +244,18 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener dataFrameActionListener = ActionListener.wrap(dataSet -> { @@ -334,12 +358,41 @@ private void runPredict( if (mlInput.getAlgorithm() == FunctionName.REMOTE) { long startTime = System.nanoTime(); ActionListener trackPredictDurationListener = ActionListener.wrap(output -> { + if (output.getOutput() instanceof ModelTensorOutput) { validateOutputSchema(modelId, (ModelTensorOutput) output.getOutput()); } - handleAsyncMLTaskComplete(mlTask); - 
mlModelManager.trackPredictDuration(modelId, startTime); - internalListener.onResponse(output); + if (mlTask.getTaskType().equals(MLTaskType.BATCH_PREDICTION)) { + Map transformJob = new HashMap<>(); + ModelTensorOutput tensorOutput = (ModelTensorOutput) output.getOutput(); + if (tensorOutput != null && tensorOutput.getMlModelOutputs() != null && !tensorOutput.getMlModelOutputs().isEmpty()) { + ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0); + if (modelOutput.getMlModelTensors() != null && !modelOutput.getMlModelTensors().isEmpty()) { + Map dataAsMap = (Map) tensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); + transformJob.putAll(dataAsMap); + mlTask.setTransformJob(transformJob); + mlTask.setTaskId(null); + mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { + String taskId = response.getId(); + MLTaskResponse predictOutput = new MLTaskResponse(taskId, MLTaskState.CREATED.name()); + internalListener.onResponse(predictOutput); + }, e -> { + logException("Failed to create task for batch predict model", e, log); + internalListener.onFailure(e); + })); + } else { + log.debug("ML Model Tensors are null or empty."); + internalListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job")); + } } + else { + log.debug("ML Model Outputs are null or empty."); + internalListener.onFailure(new ResourceNotFoundException("Couldn't fetch the batch transform job result")); + } + } else { + handleAsyncMLTaskComplete(mlTask); + mlModelManager.trackPredictDuration(modelId, startTime); + internalListener.onResponse(output); + } }, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId)); predictor.asyncPredict(mlInput, trackPredictDurationListener); } else { diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index 83da0f8273..9e3e7fc0d3 
100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -22,6 +22,7 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -29,6 +30,11 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.script.ScriptService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -46,11 +52,26 @@ public class GetTaskTransportActionTests extends OpenSearchTestCase { @Mock TransportService transportService; + @Mock + private ClusterService clusterService; + @Mock + private ScriptService scriptService; + @Mock ActionFilters actionFilters; + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Mock + private EncryptorImpl encryptor; @Mock ActionListener actionListener; + @Mock + private MLModelManager mlModelManager; + + @Mock + private MLTaskManager mlTaskManager; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -64,7 +85,10 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlTaskGetRequest = MLTaskGetRequest.builder().taskId("test_id").build(); - getTaskTransportAction = spy(new GetTaskTransportAction(transportService, 
actionFilters, client, xContentRegistry)); + getTaskTransportAction = spy(new GetTaskTransportAction(transportService, actionFilters, client, xContentRegistry, clusterService, + scriptService, + connectorAccessControlHelper, + encryptor, mlTaskManager, mlModelManager)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings);