From 224f8fcbdedc1b8bde9b71166e0d415e5eab1882 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 10:32:26 -0700 Subject: [PATCH] support get batch transform job status in get task API (#2825) (#2893) * support get batch transform job status in get task API Signed-off-by: Bhavana Ramaram * add cancel batch prediction job API for offline inference Signed-off-by: Bhavana Ramaram * add unit tests and address comments Signed-off-by: Bhavana Ramaram * stash context for get model Signed-off-by: Bhavana Ramaram * apply spotlessJava and exclude from test coverage Signed-off-by: Bhavana Ramaram --------- Signed-off-by: Bhavana Ramaram (cherry picked from commit 8da7bd235d27f4eef59583a7ac7f8b8aee6d856d) Co-authored-by: Bhavana Ramaram --- .../org/opensearch/ml/common/CommonValue.java | 6 +- .../java/org/opensearch/ml/common/MLTask.java | 33 +- .../org/opensearch/ml/common/MLTaskType.java | 1 + .../ml/common/connector/ConnectorAction.java | 4 +- .../ml/common/output/model/ModelTensors.java | 5 + .../task/MLCancelBatchJobAction.java | 17 + .../task/MLCancelBatchJobRequest.java | 70 ++++ .../task/MLCancelBatchJobResponse.java | 64 ++++ .../task/MLCancelBatchJobRequestTest.java | 75 +++++ .../task/MLCancelBatchJobResponseTest.java | 36 ++ .../algorithms/remote/ConnectorUtils.java | 5 +- .../remote/MLSdkAsyncHttpResponseHandler.java | 13 +- plugin/build.gradle | 4 +- .../tasks/CancelBatchJobTransportAction.java | 241 ++++++++++++++ .../action/tasks/GetTaskTransportAction.java | 198 ++++++++++- .../ml/plugin/MachineLearningPlugin.java | 10 +- .../ml/rest/RestMLCancelBatchJobAction.java | 68 ++++ .../ml/task/MLPredictTaskRunner.java | 69 +++- .../CancelBatchJobTransportActionTests.java | 309 ++++++++++++++++++ .../tasks/GetTaskTransportActionTests.java | 228 ++++++++++++- .../rest/RestMLCancelBatchJobActionTests.java | 105 ++++++ .../ml/task/MLPredictTaskRunnerTests.java | 107 ++++++ 22 
files changed, 1650 insertions(+), 18 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponse.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequestTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponseTest.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 06f917ee9d..edf76dc35e 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -66,7 +66,7 @@ public class CommonValue { public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11; public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; - public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; + public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 3; public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3; public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 3; @@ -393,6 +393,10 @@ public class CommonValue { + "\" : {\"type\" : \"boolean\"}, \n" + 
USER_FIELD_MAPPING + " }\n" + + "}" + + MLTask.REMOTE_JOB_FIELD + + "\" : {\"type\": \"flat_object\"}\n" + + " }\n" + "}"; public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n" diff --git a/common/src/main/java/org/opensearch/ml/common/MLTask.java b/common/src/main/java/org/opensearch/ml/common/MLTask.java index a810fa5159..1165628711 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTask.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTask.java @@ -13,7 +13,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; +import org.opensearch.Version; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -45,6 +47,8 @@ public class MLTask implements ToXContentObject, Writeable { public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; public static final String ERROR_FIELD = "error"; public static final String IS_ASYNC_TASK_FIELD = "is_async"; + public static final String REMOTE_JOB_FIELD = "remote_job"; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB = CommonValue.VERSION_2_17_0; @Setter private String taskId; @@ -66,6 +70,8 @@ public class MLTask implements ToXContentObject, Writeable { private String error; private User user; // TODO: support document level access control later private boolean async; + @Setter + private Map remoteJob; @Builder(toBuilder = true) public MLTask( @@ -82,7 +88,8 @@ public MLTask( Instant lastUpdateTime, String error, User user, - boolean async + boolean async, + Map remoteJob ) { this.taskId = taskId; this.modelId = modelId; @@ -98,9 +105,11 @@ public MLTask( this.error = error; this.user = user; this.async = async; + this.remoteJob = remoteJob; } public MLTask(StreamInput input) throws IOException { + Version streamInputVersion = input.getVersion(); this.taskId = input.readOptionalString(); this.modelId = 
input.readOptionalString(); this.taskType = input.readEnum(MLTaskType.class); @@ -123,10 +132,16 @@ public MLTask(StreamInput input) throws IOException { this.user = null; } this.async = input.readBoolean(); + if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB)) { + if (input.readBoolean()) { + this.remoteJob = input.readMap(s -> s.readString(), s -> s.readGenericValue()); + } + } } @Override public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeOptionalString(taskId); out.writeOptionalString(modelId); out.writeEnum(taskType); @@ -150,6 +165,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeBoolean(async); + if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB)) { + if (remoteJob != null) { + out.writeBoolean(true); + out.writeMap(remoteJob, StreamOutput::writeString, StreamOutput::writeGenericValue); + } else { + out.writeBoolean(false); + } + } } @Override @@ -195,6 +218,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params builder.field(USER, user); } builder.field(IS_ASYNC_TASK_FIELD, async); + if (remoteJob != null) { + builder.field(REMOTE_JOB_FIELD, remoteJob); + } return builder.endObject(); } @@ -218,6 +244,7 @@ public static MLTask parse(XContentParser parser) throws IOException { String error = null; User user = null; boolean async = false; + Map remoteJob = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -275,6 +302,9 @@ public static MLTask parse(XContentParser parser) throws IOException { case IS_ASYNC_TASK_FIELD: async = parser.booleanValue(); break; + case REMOTE_JOB_FIELD: + remoteJob = parser.map(); + break; default: parser.skipChildren(); break; @@ -296,6 +326,7 @@ public static MLTask parse(XContentParser parser) 
throws IOException { .error(error) .user(user) .async(async) + .remoteJob(remoteJob) .build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java index e17b36a4dd..179bf152cd 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java @@ -8,6 +8,7 @@ public enum MLTaskType { TRAINING, PREDICTION, + BATCH_PREDICTION, TRAINING_AND_PREDICTION, EXECUTION, @Deprecated diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 93fb5cca57..b62337d49f 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -188,7 +188,9 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { public enum ActionType { PREDICT, EXECUTE, - BATCH_PREDICT; + BATCH_PREDICT, + CANCEL_BATCH_PREDICT, + BATCH_PREDICT_STATUS; public static ActionType from(String value) { try { diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java index c3413a179f..5622057951 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java @@ -36,6 +36,11 @@ public ModelTensors(List mlModelTensors) { this.mlModelTensors = mlModelTensors; } + @Builder + public ModelTensors(Integer statusCode) { + this.statusCode = statusCode; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java 
b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java new file mode 100644 index 0000000000..6ea26c9eb3 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.task; + +import org.opensearch.action.ActionType; + +public class MLCancelBatchJobAction extends ActionType { + public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction(); + public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job"; + + private MLCancelBatchJobAction() { + super(NAME, MLCancelBatchJobResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequest.java new file mode 100644 index 0000000000..976ab69bef --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequest.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.task; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Builder; +import lombok.Getter; + +public class MLCancelBatchJobRequest extends ActionRequest { + 
@Getter + String taskId; + + @Builder + public MLCancelBatchJobRequest(String taskId) { + this.taskId = taskId; + } + + public MLCancelBatchJobRequest(StreamInput in) throws IOException { + super(in); + this.taskId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.taskId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.taskId == null) { + exception = addValidationError("ML task id can't be null", exception); + } + + return exception; + } + + public static MLCancelBatchJobRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLCancelBatchJobRequest) { + return (MLCancelBatchJobRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCancelBatchJobRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLCancelBatchJobRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponse.java new file mode 100644 index 0000000000..6e97eb9647 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponse.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.task; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import 
org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Builder; +import lombok.Getter; + +@Getter +public class MLCancelBatchJobResponse extends ActionResponse implements ToXContentObject { + + RestStatus status; + + @Builder + public MLCancelBatchJobResponse(RestStatus status) { + this.status = status; + } + + public MLCancelBatchJobResponse(StreamInput in) throws IOException { + super(in); + status = in.readEnum(RestStatus.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(status); + } + + public static MLCancelBatchJobResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCancelBatchJobResponse) { + return (MLCancelBatchJobResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCancelBatchJobResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLTaskGetResponse", e); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return xContentBuilder.startObject().field("status", status).endObject(); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequestTest.java 
b/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequestTest.java new file mode 100644 index 0000000000..e6e1f3838c --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequestTest.java @@ -0,0 +1,75 @@ +package org.opensearch.ml.common.transport.task; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLCancelBatchJobRequestTest { + private String taskId; + + @Before + public void setUp() { + taskId = "test_id"; + } + + @Test + public void writeTo_Success() throws IOException { + MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId(taskId).build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlCancelBatchJobRequest.writeTo(bytesStreamOutput); + MLCancelBatchJobRequest parsedTask = new MLCancelBatchJobRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedTask.getTaskId(), taskId); + } + + @Test + public void validate_Exception_NullTaskId() { + MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().build(); + + ActionRequestValidationException exception = mlCancelBatchJobRequest.validate(); + assertEquals("Validation Failed: 1: ML task id can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId(taskId).build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + 
@Override + public void writeTo(StreamOutput out) throws IOException { + mlCancelBatchJobRequest.writeTo(out); + } + }; + MLCancelBatchJobRequest result = MLCancelBatchJobRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlCancelBatchJobRequest); + assertEquals(result.getTaskId(), mlCancelBatchJobRequest.getTaskId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLCancelBatchJobRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponseTest.java new file mode 100644 index 0000000000..4cb3837df4 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponseTest.java @@ -0,0 +1,36 @@ +package org.opensearch.ml.common.transport.task; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +public class MLCancelBatchJobResponseTest { + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + MLCancelBatchJobResponse response = MLCancelBatchJobResponse.builder().status(RestStatus.OK).build(); + response.writeTo(bytesStreamOutput); + 
MLCancelBatchJobResponse parsedResponse = new MLCancelBatchJobResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(response.getStatus(), parsedResponse.getStatus()); + } + + @Test + public void toXContentTest() throws IOException { + MLCancelBatchJobResponse mlCancelBatchJobResponse1 = MLCancelBatchJobResponse.builder().status(RestStatus.OK).build(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + mlCancelBatchJobResponse1.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals("{\"status\":\"OK\"}", jsonStr); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index ef4f25c79a..ccceff3d68 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.remote; import static org.apache.commons.text.StringEscapeUtils.escapeJson; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING; import static org.opensearch.ml.common.connector.MLPreProcessFunction.PROCESS_REMOTE_INFERENCE_INPUT; @@ -286,7 +287,9 @@ public static SdkHttpFullRequest buildSdkRequest( } else { requestBody = RequestBody.empty(); } - if (SdkHttpMethod.POST == method && 0 == requestBody.optionalContentLength().get()) { + if (SdkHttpMethod.POST == method + && 0 == requestBody.optionalContentLength().get() + && !action.equals(CANCEL_BATCH_PREDICT.toString())) { log.error("Content length is 0. 
Aborting request to remote model"); throw new IllegalArgumentException("Content length is 0. Aborting request to remote model"); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java index 6ea03058f0..fdb686dacc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -8,6 +8,7 @@ package org.opensearch.ml.engine.algorithms.remote; import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; import java.nio.ByteBuffer; @@ -169,13 +170,14 @@ public void onComplete() { } private void response() { + String body = responseBody.toString(); + if (exceptionHolder.get() != null) { actionListener.onFailure(exceptionHolder.get()); return; } - String body = responseBody.toString(); - if (Strings.isBlank(body)) { + if (Strings.isBlank(body) && !action.equals(CANCEL_BATCH_PREDICT.toString())) { log.error("Remote model response body is empty!"); actionListener.onFailure(new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST)); return; @@ -187,6 +189,13 @@ private void response() { return; } + if (action.equals(CANCEL_BATCH_PREDICT.toString())) { + ModelTensors tensors = ModelTensors.builder().statusCode(statusCode).build(); + tensors.setStatusCode(statusCode); + actionListener.onResponse(new Tuple<>(executionContext.getSequence(), tensors)); + return; + } + try { ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard); tensors.setStatusCode(statusCode); diff --git 
a/plugin/build.gradle b/plugin/build.gradle index f4d93788db..e835ba4e0f 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -306,7 +306,9 @@ List jacocoExclusions = [ 'org.opensearch.ml.helper.ModelAccessControlHelper', 'org.opensearch.ml.action.models.DeleteModelTransportAction.2', 'org.opensearch.ml.model.MLModelCacheHelper', - 'org.opensearch.ml.model.MLModelCacheHelper.1' + 'org.opensearch.ml.model.MLModelCacheHelper.1', + 'org.opensearch.ml.action.tasks.CancelBatchJobTransportAction' + ] jacocoTestCoverageVerification { diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java new file mode 100644 index 0000000000..6a7fd617ae --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -0,0 +1,241 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.tasks; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import org.apache.hc.core5.http.HttpStatus; +import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import 
org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobAction; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobRequest; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobResponse; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.script.ScriptService; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CancelBatchJobTransportAction extends HandledTransportAction { + + Client client; + 
NamedXContentRegistry xContentRegistry; + + ClusterService clusterService; + ScriptService scriptService; + + ConnectorAccessControlHelper connectorAccessControlHelper; + EncryptorImpl encryptor; + MLModelManager mlModelManager; + + MLTaskManager mlTaskManager; + MLModelCacheHelper modelCacheHelper; + + @Inject + public CancelBatchJobTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry, + ClusterService clusterService, + ScriptService scriptService, + ConnectorAccessControlHelper connectorAccessControlHelper, + EncryptorImpl encryptor, + MLTaskManager mlTaskManager, + MLModelManager mlModelManager + ) { + super(MLCancelBatchJobAction.NAME, transportService, actionFilters, MLCancelBatchJobRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + this.clusterService = clusterService; + this.scriptService = scriptService; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.encryptor = encryptor; + this.mlTaskManager = mlTaskManager; + this.mlModelManager = mlModelManager; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.fromActionRequest(request); + String taskId = mlCancelBatchJobRequest.getTaskId(); + GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + log.debug("Completed Get Task Request, id:{}", taskId); + + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLTask mlTask = MLTask.parse(parser); + + // check if function 
is remote and task is of type batch prediction + if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION && mlTask.getFunctionName() == FunctionName.REMOTE) { + processRemoteBatchPrediction(mlTask, actionListener); + } else { + actionListener + .onFailure(new IllegalArgumentException("The task ID you provided does not have any associated batch job")); + } + } catch (Exception e) { + log.error("Failed to parse ml task " + r.getId(), e); + actionListener.onFailure(e); + } + } else { + actionListener.onFailure(new OpenSearchStatusException("Fail to find task", RestStatus.NOT_FOUND)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + actionListener.onFailure(new MLResourceNotFoundException("Fail to find task")); + } else { + log.error("Failed to get ML task " + taskId, e); + actionListener.onFailure(e); + } + }), () -> context.restore())); + } catch (Exception e) { + log.error("Failed to get ML task " + taskId, e); + actionListener.onFailure(e); + } + } + + private void processRemoteBatchPrediction(MLTask mlTask, ActionListener actionListener) { + Map remoteJob = mlTask.getRemoteJob(); + + Map parameters = new HashMap<>(); + for (Map.Entry entry : remoteJob.entrySet()) { + if (entry.getValue() instanceof String) { + parameters.put(entry.getKey(), (String) entry.getValue()); + } else { + log.debug("Value for key " + entry.getKey() + " is not a String"); + } + } + + // In sagemaker, to retrieve batch transform job details, we need transformJob name. 
So retrieving name from the arn + parameters + .computeIfAbsent( + "TransformJobName", + key -> Optional + .ofNullable(parameters.get("TransformJobArn")) + .map(jobArn -> jobArn.substring(jobArn.lastIndexOf("/") + 1)) + .orElse(null) + ); + + RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build(); + String modelId = mlTask.getModelId(); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener getModelListener = ActionListener.wrap(model -> { + if (model.getConnector() != null) { + Connector connector = model.getConnector(); + executeConnector(connector, mlInput, actionListener); + } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { + ActionListener listener = ActionListener + .wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> { + log.error("Failed to get connector " + model.getConnectorId(), e); + actionListener.onFailure(e); + }); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper + .getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore)); + } + } else { + actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + } + }, e -> { + log.error("Failed to retrieve the ML model with the given ID", e); + actionListener + .onFailure( + new OpenSearchStatusException("Failed to retrieve the ML model for the given task ID", RestStatus.NOT_FOUND) + ); + }); + mlModelManager.getModel(modelId, null, null, ActionListener.runBefore(getModelListener, context::restore)); + } catch (Exception e) { + log.error("Unable to fetch cancel batch job in ml task ", e); + throw new 
OpenSearchException("Unable to fetch cancel batch job in ml task " + e.getMessage()); + } + } + + private void executeConnector(Connector connector, MLInput mlInput, ActionListener actionListener) { + if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { + connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential)); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader + .initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> { + processTaskResponse(taskResponse, actionListener); + }, e -> { actionListener.onFailure(e); })); + } else { + actionListener + .onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN)); + } + } + + private void processTaskResponse(MLTaskResponse taskResponse, ActionListener actionListener) { + try { + ModelTensorOutput tensorOutput = (ModelTensorOutput) taskResponse.getOutput(); + if (tensorOutput != null && tensorOutput.getMlModelOutputs() != null && !tensorOutput.getMlModelOutputs().isEmpty()) { + ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0); + if (modelOutput.getStatusCode() != null && modelOutput.getStatusCode().equals(HttpStatus.SC_OK)) { + actionListener.onResponse(new MLCancelBatchJobResponse(RestStatus.OK)); + } else { + log.debug("The status code from remote service is: " + modelOutput.getStatusCode()); + actionListener.onFailure(new OpenSearchException("Couldn't cancel the transform job. 
Please try again")); + } + } else { + log.debug("ML Model Outputs are null or empty."); + actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job")); + } + } catch (Exception e) { + log.error("Unable to fetch status for ml task ", e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 88c05f71c1..01b4724046 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -6,15 +6,29 @@ package org.opensearch.ml.action.tasks; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.MLTask.REMOTE_JOB_FIELD; +import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTaskState.CANCELLED; +import static org.opensearch.ml.common.MLTaskState.COMPLETED; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS; +import static org.opensearch.ml.utils.MLExceptionUtils.logException; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import 
org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -22,11 +36,29 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.task.MLTaskGetAction; import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.script.ScriptService; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -38,16 +70,38 @@ public class GetTaskTransportAction extends HandledTransportAction actionListener) { + Map remoteJob = mlTask.getRemoteJob(); + + Map parameters = new HashMap<>(); + for (Map.Entry entry : remoteJob.entrySet()) { + if (entry.getValue() instanceof 
String) { + parameters.put(entry.getKey(), (String) entry.getValue()); + } else { + log.debug("Value for key " + entry.getKey() + " is not a String"); + } + } + // In sagemaker, to retrieve batch transform job details, we need transformJob name. So retrieving name from the arn + parameters + .computeIfAbsent( + "TransformJobName", + key -> Optional + .ofNullable(parameters.get("TransformJobArn")) + .map(jobArn -> jobArn.substring(jobArn.lastIndexOf("/") + 1)) + .orElse(null) + ); + + RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build(); + String modelId = mlTask.getModelId(); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener getModelListener = ActionListener.wrap(model -> { + if (model.getConnector() != null) { + Connector connector = model.getConnector(); + executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener); + } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { + ActionListener listener = ActionListener.wrap(connector -> { + executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener); + }, e -> { + log.error("Failed to get connector " + model.getConnectorId(), e); + actionListener.onFailure(e); + }); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper + .getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore)); + } + } else { + actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + } + }, e -> { + log.error("Failed to retrieve the ML model for the given task ID", e); + actionListener + .onFailure( + new OpenSearchStatusException("Failed to retrieve 
the ML model for the given task ID", RestStatus.NOT_FOUND) + ); + }); + mlModelManager.getModel(modelId, null, null, ActionListener.runBefore(getModelListener, context::restore)); + } catch (Exception e) { + log.error("Unable to fetch status for ml task ", e); + throw new OpenSearchException("Unable to fetch status for ml task " + e.getMessage()); + } + } + + private void executeConnector( + Connector connector, + MLInput mlInput, + String taskId, + MLTask mlTask, + Map transformJob, + ActionListener actionListener + ) { + if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { + connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential)); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader + .initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> { + processTaskResponse(mlTask, taskId, taskResponse, transformJob, actionListener); + }, e -> { actionListener.onFailure(e); })); + } else { + actionListener + .onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN)); + } + } + + private void processTaskResponse( + MLTask mlTask, + String taskId, + MLTaskResponse taskResponse, + Map remoteJob, + ActionListener actionListener + ) { + try { + ModelTensorOutput tensorOutput = (ModelTensorOutput) taskResponse.getOutput(); + if (tensorOutput != null && tensorOutput.getMlModelOutputs() != null && !tensorOutput.getMlModelOutputs().isEmpty()) { + ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0); + if (modelOutput.getMlModelTensors() != null && !modelOutput.getMlModelTensors().isEmpty()) { + Map remoteJobStatus = 
(Map) modelOutput.getMlModelTensors().get(0).getDataAsMap(); + if (remoteJobStatus != null) { + remoteJob.putAll(remoteJobStatus); + Map updatedTask = new HashMap<>(); + updatedTask.put(REMOTE_JOB_FIELD, remoteJob); + + if ((remoteJob.containsKey("status") && remoteJob.get("status").equals("completed")) + || (remoteJob.containsKey("TransformJobStatus") && remoteJob.get("TransformJobStatus").equals("Completed"))) { + updatedTask.put(STATE_FIELD, COMPLETED); + mlTask.setState(COMPLETED); + + } else if ((remoteJob.containsKey("status") && remoteJob.get("status").equals("cancelled")) + || (remoteJob.containsKey("TransformJobStatus") && remoteJob.get("TransformJobStatus").equals("Stopped"))) { + updatedTask.put(STATE_FIELD, CANCELLED); + mlTask.setState(CANCELLED); + } + mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> { + actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build()); + }, e -> { + logException("Failed to update task for batch predict model", e, log); + actionListener.onFailure(e); + })); + } else { + log.debug("Transform job status is null."); + actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job")); + } + } else { + log.debug("ML Model Tensors are null or empty."); + actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job")); + } + } else { + log.debug("ML Model Outputs are null or empty."); + actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job")); + } + } catch (Exception e) { + log.error("Unable to fetch status for ml task ", e); + } } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index d77eb63540..4542398091 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ 
b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -83,6 +83,7 @@ import org.opensearch.ml.action.stats.MLStatsNodesAction; import org.opensearch.ml.action.stats.MLStatsNodesTransportAction; import org.opensearch.ml.action.syncup.TransportSyncUpOnNodeAction; +import org.opensearch.ml.action.tasks.CancelBatchJobTransportAction; import org.opensearch.ml.action.tasks.DeleteTaskTransportAction; import org.opensearch.ml.action.tasks.GetTaskTransportAction; import org.opensearch.ml.action.tasks.SearchTaskTransportAction; @@ -151,6 +152,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.register.MLRegisterModelAction; import org.opensearch.ml.common.transport.sync.MLSyncUpAction; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobAction; import org.opensearch.ml.common.transport.task.MLTaskDeleteAction; import org.opensearch.ml.common.transport.task.MLTaskGetAction; import org.opensearch.ml.common.transport.task.MLTaskSearchAction; @@ -219,6 +221,7 @@ import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.rest.RestMLBatchIngestAction; +import org.opensearch.ml.rest.RestMLCancelBatchJobAction; import org.opensearch.ml.rest.RestMLCreateConnectorAction; import org.opensearch.ml.rest.RestMLCreateControllerAction; import org.opensearch.ml.rest.RestMLDeleteAgentAction; @@ -444,7 +447,8 @@ public MachineLearningPlugin(Settings settings) { new ActionHandler<>(MLListToolsAction.INSTANCE, ListToolsTransportAction.class), new ActionHandler<>(MLGetToolAction.INSTANCE, GetToolTransportAction.class), new ActionHandler<>(MLConfigGetAction.INSTANCE, GetConfigTransportAction.class), - new ActionHandler<>(MLBatchIngestionAction.INSTANCE, TransportBatchIngestionAction.class) + new ActionHandler<>(MLBatchIngestionAction.INSTANCE, 
TransportBatchIngestionAction.class), + new ActionHandler<>(MLCancelBatchJobAction.INSTANCE, CancelBatchJobTransportAction.class) ); } @@ -765,6 +769,7 @@ public List getRestHandlers( RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(toolFactories); RestMLGetConfigAction restMLGetConfigAction = new RestMLGetConfigAction(); RestMLBatchIngestAction restMLBatchIngestAction = new RestMLBatchIngestAction(); + RestMLCancelBatchJobAction restMLCancelBatchJobAction = new RestMLCancelBatchJobAction(); return ImmutableList .of( restMLStatsAction, @@ -818,7 +823,8 @@ public List getRestHandlers( restMLListToolsAction, restMLGetToolAction, restMLGetConfigAction, - restMLBatchIngestAction + restMLBatchIngestAction, + restMLCancelBatchJobAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java new file mode 100644 index 0000000000..33c7314be2 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TASK_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobAction; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class 
RestMLCancelBatchJobAction extends BaseRestHandler { + private static final String ML_CANCEL_BATCH_ACTION = "ml_cancel_batch_action"; + + /** + * Constructor + */ + public RestMLCancelBatchJobAction() {} + + @Override + public String getName() { + return ML_CANCEL_BATCH_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel_batch", ML_BASE_URI, PARAMETER_TASK_ID) + ) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLCancelBatchJobRequest mlCancelBatchJobRequest = getRequest(request); + return channel -> client.execute(MLCancelBatchJobAction.INSTANCE, mlCancelBatchJobRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLCancelBatchJobRequest from a RestRequest + * + * @param request RestRequest + * @return MLCancelBatchJobRequest + */ + @VisibleForTesting + MLCancelBatchJobRequest getRequest(RestRequest request) throws IOException { + String taskId = getParameterId(request, PARAMETER_TASK_ID); + + return new MLCancelBatchJobRequest(taskId); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 72c43bd58f..3b2e70d4b8 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -14,9 +14,12 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.PREDICT_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; +import static org.opensearch.ml.utils.MLExceptionUtils.logException; import java.time.Instant; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import 
java.util.UUID; import org.opensearch.OpenSearchException; @@ -48,6 +51,7 @@ import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -55,6 +59,7 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; @@ -228,11 +233,18 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener dataFrameActionListener = ActionListener.wrap(dataSet -> { @@ -336,12 +347,60 @@ private void runPredict( if (mlInput.getAlgorithm() == FunctionName.REMOTE) { long startTime = System.nanoTime(); ActionListener trackPredictDurationListener = ActionListener.wrap(output -> { + if (output.getOutput() instanceof ModelTensorOutput) { validateOutputSchema(modelId, (ModelTensorOutput) output.getOutput()); } - handleAsyncMLTaskComplete(mlTask); - mlModelManager.trackPredictDuration(modelId, startTime); - internalListener.onResponse(output); + if (mlTask.getTaskType().equals(MLTaskType.BATCH_PREDICTION)) { + Map remoteJob = new HashMap<>(); + ModelTensorOutput tensorOutput = (ModelTensorOutput) output.getOutput(); + if (tensorOutput != null + && tensorOutput.getMlModelOutputs() != null + && !tensorOutput.getMlModelOutputs().isEmpty()) { + ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0); + if (modelOutput.getMlModelTensors() != null && 
!modelOutput.getMlModelTensors().isEmpty()) { + Map dataAsMap = (Map) modelOutput + .getMlModelTensors() + .get(0) + .getDataAsMap(); + if (dataAsMap != null + && (dataAsMap.containsKey("TransformJobArn") || dataAsMap.containsKey("id"))) { + remoteJob.putAll(dataAsMap); + mlTask.setRemoteJob(remoteJob); + mlTask.setTaskId(null); + mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { + String taskId = response.getId(); + mlTask.setTaskId(taskId); + MLPredictionOutput outputBuilder = MLPredictionOutput + .builder() + .taskId(taskId) + .status(MLTaskState.CREATED.name()) + .build(); + + MLTaskResponse predictOutput = MLTaskResponse.builder().output(outputBuilder).build(); + internalListener.onResponse(predictOutput); + }, e -> { + logException("Failed to create task for batch predict model", e, log); + internalListener.onFailure(e); + })); + } else { + log.debug("Batch transform job output from remote model did not return the job ID"); + internalListener + .onFailure(new ResourceNotFoundException("Unable to create batch transform job")); + } + } else { + log.debug("ML Model Tensors are null or empty."); + internalListener.onFailure(new ResourceNotFoundException("Unable to create batch transform job")); + } + } else { + log.debug("ML Model Outputs are null or empty."); + internalListener.onFailure(new ResourceNotFoundException("Unable to create batch transform job")); + } + } else { + handleAsyncMLTaskComplete(mlTask); + mlModelManager.trackPredictDuration(modelId, startTime); + internalListener.onResponse(output); + } }, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName)); predictor.asyncPredict(mlInput, trackPredictDurationListener); } else { diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java new file mode 100644 index 0000000000..99d9fbf8a1 --- /dev/null +++ 
b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java @@ -0,0 +1,309 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.tasks; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.FunctionName; +import 
org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobRequest; +import org.opensearch.ml.common.transport.task.MLCancelBatchJobResponse; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.script.ScriptService; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class CancelBatchJobTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + private ClusterService clusterService; + @Mock + private ScriptService scriptService; + @Mock + ClusterState clusterState; + + @Mock + private Metadata metaData; + + @Mock + ActionFilters actionFilters; + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Mock + private EncryptorImpl encryptor; + + @Mock + ActionListener actionListener; + @Mock + private MLModelManager mlModelManager; + + @Mock + private MLTaskManager mlTaskManager; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + CancelBatchJobTransportAction cancelBatchJobTransportAction; + MLCancelBatchJobRequest mlCancelBatchJobRequest; + ThreadContext threadContext; + + @Before + public void 
setup() throws IOException { + MockitoAnnotations.openMocks(this); + mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId("test_id").build(); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + doReturn(clusterState).when(clusterService).state(); + doReturn(metaData).when(clusterState).metadata(); + + doReturn(true).when(metaData).hasIndex(anyString()); + + cancelBatchJobTransportAction = spy( + new CancelBatchJobTransportAction( + transportService, + actionFilters, + client, + xContentRegistry, + clusterService, + scriptService, + connectorAccessControlHelper, + encryptor, + mlTaskManager, + mlModelManager + ) + ); + + MLModel mlModel = mock(MLModel.class); + + Connector connector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.BATCH_PREDICT_STATUS) + .method("POST") + .url("https://api.sagemaker.us-east-1.amazonaws.com/DescribeTransformJob") + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}") + .build() + ) + ) + .build(); + + when(mlModel.getConnectorId()).thenReturn("testConnectorID"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(eq("testModelID"), any(), any(), isA(ActionListener.class)); + + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(connector); + return null; + 
}).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + } + + public void testGetTask_NullResponse() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find task", argumentCaptor.getValue().getMessage()); + } + + public void testGetTask_RuntimeException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("errorMessage")); + return null; + }).when(client).get(any(), any()); + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + } + + public void testGetTask_IndexNotFoundException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("Index Not Found")); + return null; + }).when(client).get(any(), any()); + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find task", argumentCaptor.getValue().getMessage()); + } + + @Ignore + public void testGetTask_SuccessBatchPredictCancel() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + GetResponse getResponse = 
prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("TransformJobStatus", "COMPLETED")).build(); + ModelTensorOutput modelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build())) + .build(); + + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + verify(actionListener).onResponse(any(MLCancelBatchJobResponse.class)); + } + + public void test_BatchPredictCancel_NoConnector() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(false); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); + } + + public void test_BatchPredictStatus_NoAccessToConnector() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + doAnswer(invocation -> { + ActionListener listener = 
invocation.getArgument(2); + listener.onFailure(new ResourceNotFoundException("Failed to get connector")); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get connector", argumentCaptor.getValue().getMessage()); + } + + public GetResponse prepareMLTask(FunctionName functionName, MLTaskType mlTaskType, Map remoteJob) throws IOException { + MLTask mlTask = MLTask + .builder() + .taskId("taskID") + .modelId("testModelID") + .functionName(functionName) + .taskType(mlTaskType) + .remoteJob(remoteJob) + .build(); + XContentBuilder content = mlTask.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + return getResponse; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index 83da0f8273..3707c89eae 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -6,29 +6,63 @@ package org.opensearch.ml.action.tasks; import static org.mockito.ArgumentMatchers.any; +import static 
org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.Before; +import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.HttpConnector; +import 
org.opensearch.ml.common.output.model.ModelTensor;
+import org.opensearch.ml.common.output.model.ModelTensorOutput;
+import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
 import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
+import org.opensearch.ml.engine.encryptor.EncryptorImpl;
+import org.opensearch.ml.helper.ConnectorAccessControlHelper;
+import org.opensearch.ml.model.MLModelManager;
+import org.opensearch.ml.task.MLTaskManager;
+import org.opensearch.script.ScriptService;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.transport.TransportService;
@@ -46,11 +80,31 @@ public class GetTaskTransportActionTests extends OpenSearchTestCase {
     @Mock
     TransportService transportService;
 
+    @Mock
+    private ClusterService clusterService;
+    @Mock
+    private ScriptService scriptService;
+    @Mock
+    ClusterState clusterState;
+
+    @Mock
+    private Metadata metaData;
+
     @Mock
     ActionFilters actionFilters;
+    @Mock
+    private ConnectorAccessControlHelper connectorAccessControlHelper;
+
+    @Mock
+    private EncryptorImpl encryptor;
 
     @Mock
     ActionListener<MLTaskGetResponse> actionListener;
+    @Mock
+    private MLModelManager mlModelManager;
+
+    @Mock
+    private MLTaskManager mlTaskManager;
 
     @Rule
     public ExpectedException exceptionRule = ExpectedException.none();
@@ -64,12 +118,73 @@ public void setup() throws IOException {
         MockitoAnnotations.openMocks(this);
         mlTaskGetRequest = MLTaskGetRequest.builder().taskId("test_id").build();
 
-        getTaskTransportAction = spy(new GetTaskTransportAction(transportService, actionFilters, client, xContentRegistry));
-
         Settings settings = Settings.builder().build();
         threadContext = new ThreadContext(settings);
         when(client.threadPool()).thenReturn(threadPool);
         when(threadPool.getThreadContext()).thenReturn(threadContext);
+
+        // Cluster state whose metadata reports every index as present.
+        doReturn(clusterState).when(clusterService).state();
+        doReturn(metaData).when(clusterState).metadata();
+
+        doReturn(true).when(metaData).hasIndex(anyString());
+
+        getTaskTransportAction = spy(
+            new GetTaskTransportAction(
+                transportService,
+                actionFilters,
+                client,
+                xContentRegistry,
+                clusterService,
+                scriptService,
+                connectorAccessControlHelper,
+                encryptor,
+                mlTaskManager,
+                mlModelManager
+            )
+        );
+
+        // Default collaborators for the remote batch path; individual tests re-stub these to inject failures.
+        MLModel mlModel = mock(MLModel.class);
+
+        Connector connector = HttpConnector
+            .builder()
+            .name("test")
+            .protocol("http")
+            .version("1")
+            .credential(Map.of("api_key", "credential_value"))
+            .parameters(Map.of("param1", "value1"))
+            .actions(
+                Arrays
+                    .asList(
+                        ConnectorAction
+                            .builder()
+                            .actionType(ConnectorAction.ActionType.BATCH_PREDICT_STATUS)
+                            .method("POST")
+                            .url("https://api.sagemaker.us-east-1.amazonaws.com/DescribeTransformJob")
+                            .headers(Map.of("Authorization", "Bearer ${credential.api_key}"))
+                            .requestBody("{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}")
+                            .build()
+                    )
+            )
+            .build();
+
+        when(mlModel.getConnectorId()).thenReturn("testConnectorID");
+
+        doAnswer(invocation -> {
+            ActionListener<MLModel> listener = invocation.getArgument(3);
+            listener.onResponse(mlModel);
+            return null;
+        }).when(mlModelManager).getModel(eq("testModelID"), any(), any(), isA(ActionListener.class));
+
+        when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true);
+
+        doAnswer(invocation -> {
+            ActionListener<Connector> listener = invocation.getArgument(2);
+            listener.onResponse(connector);
+            return null;
+        }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any());
+
     }
 
     public void testGetTask_NullResponse() {
@@ -107,4 +222,121 @@ public void testGetTask_IndexNotFoundException() {
         verify(actionListener).onFailure(argumentCaptor.capture());
         assertEquals("Fail to find task", argumentCaptor.getValue().getMessage());
     }
+
+    // NOTE(review): ignored without a stated reason; modelTensor/modelTensorOutput below are built but never
+    // passed to any stub, so the remote status call is presumably still unmocked - confirm before enabling.
+    @Ignore
+    public void testGetTask_SuccessBatchPredictStatus() throws IOException {
+        Map<String, Object> remoteJob = new HashMap<>();
+        remoteJob.put("Status", "IN PROGRESS");
+        remoteJob.put("TransformJobName", "SM-offline-batch-transform13");
+
+        GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob);
+
+        doAnswer(invocation -> {
+            ActionListener<GetResponse> getListener = invocation.getArgument(1);
+            getListener.onResponse(getResponse);
+            return null;
+        }).when(client).get(any(), any());
+
+        ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("TransformJobStatus", "COMPLETED")).build();
+        ModelTensorOutput modelTensorOutput = ModelTensorOutput
+            .builder()
+            .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build()))
+            .build();
+
+        getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener);
+        verify(actionListener).onResponse(any(MLTaskGetResponse.class));
+    }
+
+    public void test_BatchPredictStatus_NoAccessToConnector() throws IOException {
+        Map<String, Object> remoteJob = new HashMap<>();
+        remoteJob.put("Status", "IN PROGRESS");
+        remoteJob.put("TransformJobName", "SM-offline-batch-transform13");
+
+        // The connector can be fetched, but the access check rejects the caller.
+        when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(false);
+
+        GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob);
+
+        doAnswer(invocation -> {
+            ActionListener<GetResponse> getListener = invocation.getArgument(1);
+            getListener.onResponse(getResponse);
+            return null;
+        }).when(client).get(any(), any());
+
+        getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener);
+        ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
+        verify(actionListener).onFailure(argumentCaptor.capture());
+        assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage());
+    }
+
+    public void test_BatchPredictStatus_NoConnector() throws IOException {
+        Map<String, Object> remoteJob = new HashMap<>();
+        remoteJob.put("Status", "IN PROGRESS");
+        remoteJob.put("TransformJobName", "SM-offline-batch-transform13");
+
+        // The connector lookup itself fails.
+        doAnswer(invocation -> {
+            ActionListener<Connector> listener = invocation.getArgument(2);
+            listener.onFailure(new ResourceNotFoundException("Failed to get connector"));
+            return null;
+        }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any());
+
+        GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob);
+
+        doAnswer(invocation -> {
+            ActionListener<GetResponse> getListener = invocation.getArgument(1);
+            getListener.onResponse(getResponse);
+            return null;
+        }).when(client).get(any(), any());
+
+        getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener);
+        ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
+        verify(actionListener).onFailure(argumentCaptor.capture());
+        assertEquals("Failed to get connector", argumentCaptor.getValue().getMessage());
+    }
+
+    // NOTE(review): despite the name, this repeats the connector-lookup failure above - the getModel stub from
+    // setup() still succeeds. TODO: make mlModelManager.getModel fail here and assert on that error instead.
+    public void test_BatchPredictStatus_NoModel() throws IOException {
+        Map<String, Object> remoteJob = new HashMap<>();
+        remoteJob.put("Status", "IN PROGRESS");
+        remoteJob.put("TransformJobName", "SM-offline-batch-transform13");
+
+        doAnswer(invocation -> {
+            ActionListener<Connector> listener = invocation.getArgument(2);
+            listener.onFailure(new ResourceNotFoundException("Failed to get connector"));
+            return null;
+        }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any());
+
+        GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob);
+
+        doAnswer(invocation -> {
+            ActionListener<GetResponse> getListener = invocation.getArgument(1);
+            getListener.onResponse(getResponse);
+            return null;
+        }).when(client).get(any(), any());
+
+        getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener);
+        ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
+        verify(actionListener).onFailure(argumentCaptor.capture());
+        assertEquals("Failed to get connector", argumentCaptor.getValue().getMessage());
+    }
+
+    // Wraps a serialized MLTask ("taskID" / "testModelID") in a GetResponse, as the client.get() stubs return it.
+    public GetResponse prepareMLTask(FunctionName functionName, MLTaskType mlTaskType, Map<String, Object> remoteJob) throws IOException {
+        MLTask mlTask = MLTask
+            .builder()
+            .taskId("taskID")
+            .modelId("testModelID")
+            .functionName(functionName)
+            .taskType(mlTaskType)
+            .remoteJob(remoteJob)
+            .build();
+        XContentBuilder content = mlTask.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
+        BytesReference bytesReference = BytesReference.bytes(content);
+        GetResult getResult = new GetResult("indexName", "111", 111L, 111L, 111L, true, bytesReference, null, null);
+        return new GetResponse(getResult);
+    }
 }
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java
new file mode 100644
index 0000000000..1498750e6a
--- /dev/null
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.rest;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TASK_ID;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.junit.Before;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.Strings;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.ml.common.transport.task.MLCancelBatchJobAction;
+import org.opensearch.ml.common.transport.task.MLCancelBatchJobRequest;
+import
org.opensearch.ml.common.transport.task.MLCancelBatchJobResponse;
+import org.opensearch.rest.RestChannel;
+import org.opensearch.rest.RestHandler;
+import org.opensearch.rest.RestRequest;
+import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.test.rest.FakeRestRequest;
+import org.opensearch.threadpool.TestThreadPool;
+import org.opensearch.threadpool.ThreadPool;
+
+/** Unit tests for RestMLCancelBatchJobAction: handler name, registered route, and request dispatch. */
+public class RestMLCancelBatchJobActionTests extends OpenSearchTestCase {
+
+    private RestMLCancelBatchJobAction restMLCancelBatchJobAction;
+
+    NodeClient client;
+    private ThreadPool threadPool;
+
+    @Mock
+    RestChannel channel;
+
+    @Before
+    public void setup() {
+        org.mockito.MockitoAnnotations.openMocks(this); // required for the @Mock channel field to be non-null
+        restMLCancelBatchJobAction = new RestMLCancelBatchJobAction();
+
+        threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
+        client = spy(new NodeClient(Settings.EMPTY, threadPool));
+
+        // Short-circuit execute() on the spy so the cancel action is never dispatched for real.
+        doAnswer(invocation -> {
+            ActionListener<MLCancelBatchJobResponse> actionListener = invocation.getArgument(2);
+            return null;
+        }).when(client).execute(eq(MLCancelBatchJobAction.INSTANCE), any(), any());
+    }
+
+    @Override
+    public void tearDown() throws Exception {
+        super.tearDown();
+        threadPool.shutdown();
+        client.close();
+    }
+
+    public void testConstructor() {
+        assertNotNull(new RestMLCancelBatchJobAction());
+    }
+
+    public void testGetName() {
+        String actionName = restMLCancelBatchJobAction.getName();
+        assertFalse(Strings.isNullOrEmpty(actionName));
+        assertEquals("ml_cancel_batch_action", actionName);
+    }
+
+    public void testRoutes() {
+        List<RestHandler.Route> routes = restMLCancelBatchJobAction.routes();
+        assertNotNull(routes);
+        assertFalse(routes.isEmpty());
+        RestHandler.Route route = routes.get(0);
+        assertEquals(RestRequest.Method.POST, route.getMethod());
+        assertEquals("/_plugins/_ml/tasks/{task_id}/_cancel_batch", route.getPath());
+    }
+
+    public void test_PrepareRequest() throws Exception {
+        RestRequest request = getRestRequest();
+        restMLCancelBatchJobAction.handleRequest(request, channel, client);
+
+        ArgumentCaptor<MLCancelBatchJobRequest> argumentCaptor = ArgumentCaptor.forClass(MLCancelBatchJobRequest.class);
+        verify(client, times(1)).execute(eq(MLCancelBatchJobAction.INSTANCE), argumentCaptor.capture(), any());
+        assertEquals("test_id", argumentCaptor.getValue().getTaskId());
+    }
+
+    private RestRequest getRestRequest() {
+        Map<String, String> params = new HashMap<>();
+        params.put(PARAMETER_TASK_ID, "test_id");
+        return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build();
+    }
+}
diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java
index cbde703543..064008a9c4 100644
--- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java
@@ -25,10 +25,12 @@
 import org.junit.rules.ExpectedException;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
+import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 import org.opensearch.OpenSearchStatusException;
 import org.opensearch.Version;
 import org.opensearch.action.get.GetResponse;
+import org.opensearch.action.index.IndexResponse;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.cluster.service.ClusterService;
@@ -49,11 +51,13 @@
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.MLModel;
 import org.opensearch.ml.common.MLTask;
+import org.opensearch.ml.common.connector.ConnectorAction;
 import org.opensearch.ml.common.dataframe.DataFrame;
 import org.opensearch.ml.common.dataset.DataFrameInputDataset;
 import org.opensearch.ml.common.dataset.MLInputDataset;
 import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
 import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
+import
org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams;
 import org.opensearch.ml.common.output.MLPredictionOutput;
@@ -412,6 +416,100 @@ public void testValidateModelTensorOutputSuccess() {
         taskRunner.validateOutputSchema("testId", modelTensorOutput);
     }
 
+    public void testValidateBatchPredictionSuccess() throws IOException {
+        setupMocks(true, false, false, false);
+        MLPredictionTaskRequest remoteInputRequest = newBatchPredictRequest();
+        Predictable predictor = mock(Predictable.class);
+        when(predictor.isModelReady()).thenReturn(true);
+        // The predictor answers with a single tensor carrying the transform job ARN.
+        ModelTensor modelTensor = ModelTensor
+            .builder()
+            .name("response")
+            .dataAsMap(Map.of("TransformJobArn", "arn:aws:sagemaker:us-east-1:802041417063:transform-job/batch-transform-01"))
+            .build();
+        Map<String, String> modelInterface = Map
+            .of(
+                "output",
+                "{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"array\"}}}"
+            );
+        ModelTensorOutput modelTensorOutput = ModelTensorOutput
+            .builder()
+            .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build()))
+            .build();
+        doAnswer(invocation -> {
+            ActionListener<MLTaskResponse> predictListener = invocation.getArgument(1);
+            predictListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build());
+            return null;
+        }).when(predictor).asyncPredict(any(), any());
+
+        stubTaskCreation("mockTaskId");
+
+        when(mlModelManager.getModelInterface(any())).thenReturn(modelInterface);
+
+        when(mlModelManager.getPredictor(anyString())).thenReturn(predictor);
+        when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.REMOTE), eq(true))).thenReturn(new String[] { "node1" });
+        taskRunner.dispatchTask(FunctionName.REMOTE, remoteInputRequest, transportService, listener);
+        verify(client, never()).get(any(), any());
+        ArgumentCaptor<MLTaskResponse> argumentCaptor = ArgumentCaptor.forClass(MLTaskResponse.class);
+        verify(listener).onResponse(argumentCaptor.capture());
+        assertNotNull(argumentCaptor.getValue());
+    }
+
+    public void testValidateBatchPredictionFailure() throws IOException {
+        setupMocks(true, false, false, false);
+        MLPredictionTaskRequest remoteInputRequest = newBatchPredictRequest();
+        Predictable predictor = mock(Predictable.class);
+        when(predictor.isModelReady()).thenReturn(true);
+        doAnswer(invocation -> {
+            ActionListener<MLTaskResponse> predictListener = invocation.getArgument(1);
+            // The predictor answers with an output that holds no tensors at all.
+            predictListener
+                .onResponse(MLTaskResponse.builder().output(ModelTensorOutput.builder().mlModelOutputs(List.of()).build()).build());
+            return null;
+        }).when(predictor).asyncPredict(any(), any());
+
+        stubTaskCreation("mockTaskId");
+
+        when(mlModelManager.getPredictor(anyString())).thenReturn(predictor);
+        when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.REMOTE), eq(true))).thenReturn(new String[] { "node1" });
+        taskRunner.dispatchTask(FunctionName.REMOTE, remoteInputRequest, transportService, listener);
+        ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
+        verify(listener).onFailure(argumentCaptor.capture());
+        assertEquals("Unable to create batch transform job", argumentCaptor.getValue().getMessage());
+    }
+
+    // Builds a REMOTE prediction request whose input dataset asks for the BATCH_PREDICT connector action.
+    private MLPredictionTaskRequest newBatchPredictRequest() {
+        RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet
+            .builder()
+            .parameters(
+                Map
+                    .of(
+                        "messages",
+                        "[{\\\"role\\\":\\\"system\\\",\\\"content\\\":\\\"You are a helpful assistant.\\\"},"
+                            + "{\\\"role\\\":\\\"user\\\",\\\"content\\\":\\\"Hello!\\\"}]"
+                    )
+            )
+            .actionType(ConnectorAction.ActionType.BATCH_PREDICT)
+            .build();
+        return MLPredictionTaskRequest
+            .builder()
+            .modelId("test_model")
+            .mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build())
+            .build();
+    }
+
+    // Stubs mlTaskManager.createMLTask to answer with an index response that reports the given task id.
+    private void stubTaskCreation(String taskId) {
+        IndexResponse indexResponse = mock(IndexResponse.class);
+        when(indexResponse.getId()).thenReturn(taskId);
+        doAnswer(invocation -> {
+            ActionListener<IndexResponse> indexListener = invocation.getArgument(1);
+            indexListener.onResponse(indexResponse);
+            return null;
+        }).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class));
+    }
+
     public void testValidateModelTensorOutputFailed() {
         exceptionRule.expect(OpenSearchStatusException.class);
         ModelTensor modelTensor = ModelTensor