Skip to content

Commit

Permalink
support get batch transform job status in get task API (opensearch-pr…
Browse files Browse the repository at this point in the history
…oject#2825) (opensearch-project#2893)

* support get batch transform job status in get task API

Signed-off-by: Bhavana Ramaram <[email protected]>

* add cancel batch prediction job API for offline inference

Signed-off-by: Bhavana Ramaram <[email protected]>

* add unit tests and address comments

Signed-off-by: Bhavana Ramaram <[email protected]>

* stash context for get model

Signed-off-by: Bhavana Ramaram <[email protected]>

* apply spotlessJava and exclude from test coverage

Signed-off-by: Bhavana Ramaram <[email protected]>

---------

Signed-off-by: Bhavana Ramaram <[email protected]>
(cherry picked from commit 8da7bd2)

Co-authored-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and rbhavna authored Sep 5, 2024
1 parent a91843c commit 224f8fc
Show file tree
Hide file tree
Showing 22 changed files with 1,650 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public class CommonValue {
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 3;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 3;
Expand Down Expand Up @@ -393,6 +393,10 @@ public class CommonValue {
+ "\" : {\"type\" : \"boolean\"}, \n"
+ USER_FIELD_MAPPING
+ " }\n"
+ "}"
+ MLTask.REMOTE_JOB_FIELD
+ "\" : {\"type\": \"flat_object\"}\n"
+ " }\n"
+ "}";

public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n"
Expand Down
33 changes: 32 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.opensearch.Version;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -45,6 +47,8 @@ public class MLTask implements ToXContentObject, Writeable {
public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
public static final String ERROR_FIELD = "error";
public static final String IS_ASYNC_TASK_FIELD = "is_async";
public static final String REMOTE_JOB_FIELD = "remote_job";
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB = CommonValue.VERSION_2_17_0;

@Setter
private String taskId;
Expand All @@ -66,6 +70,8 @@ public class MLTask implements ToXContentObject, Writeable {
private String error;
private User user; // TODO: support document level access control later
private boolean async;
@Setter
private Map<String, Object> remoteJob;

@Builder(toBuilder = true)
public MLTask(
Expand All @@ -82,7 +88,8 @@ public MLTask(
Instant lastUpdateTime,
String error,
User user,
boolean async
boolean async,
Map<String, Object> remoteJob
) {
this.taskId = taskId;
this.modelId = modelId;
Expand All @@ -98,9 +105,11 @@ public MLTask(
this.error = error;
this.user = user;
this.async = async;
this.remoteJob = remoteJob;
}

public MLTask(StreamInput input) throws IOException {
Version streamInputVersion = input.getVersion();
this.taskId = input.readOptionalString();
this.modelId = input.readOptionalString();
this.taskType = input.readEnum(MLTaskType.class);
Expand All @@ -123,10 +132,16 @@ public MLTask(StreamInput input) throws IOException {
this.user = null;
}
this.async = input.readBoolean();
if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB)) {
if (input.readBoolean()) {
this.remoteJob = input.readMap(s -> s.readString(), s -> s.readGenericValue());
}
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeOptionalString(taskId);
out.writeOptionalString(modelId);
out.writeEnum(taskType);
Expand All @@ -150,6 +165,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeBoolean(async);
if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB)) {
if (remoteJob != null) {
out.writeBoolean(true);
out.writeMap(remoteJob, StreamOutput::writeString, StreamOutput::writeGenericValue);
} else {
out.writeBoolean(false);
}
}
}

@Override
Expand Down Expand Up @@ -195,6 +218,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
builder.field(USER, user);
}
builder.field(IS_ASYNC_TASK_FIELD, async);
if (remoteJob != null) {
builder.field(REMOTE_JOB_FIELD, remoteJob);
}
return builder.endObject();
}

Expand All @@ -218,6 +244,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
String error = null;
User user = null;
boolean async = false;
Map<String, Object> remoteJob = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -275,6 +302,9 @@ public static MLTask parse(XContentParser parser) throws IOException {
case IS_ASYNC_TASK_FIELD:
async = parser.booleanValue();
break;
case REMOTE_JOB_FIELD:
remoteJob = parser.map();
break;
default:
parser.skipChildren();
break;
Expand All @@ -296,6 +326,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
.error(error)
.user(user)
.async(async)
.remoteJob(remoteJob)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
public enum MLTaskType {
TRAINING,
PREDICTION,
BATCH_PREDICTION,
TRAINING_AND_PREDICTION,
EXECUTION,
@Deprecated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
public enum ActionType {
PREDICT,
EXECUTE,
BATCH_PREDICT;
BATCH_PREDICT,
CANCEL_BATCH_PREDICT,
BATCH_PREDICT_STATUS;

public static ActionType from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public ModelTensors(List<ModelTensor> mlModelTensors) {
this.mlModelTensors = mlModelTensors;
}

@Builder
public ModelTensors(Integer statusCode) {
this.statusCode = statusCode;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.task;

import org.opensearch.action.ActionType;

public class MLCancelBatchJobAction extends ActionType<MLCancelBatchJobResponse> {
public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction();
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job";

private MLCancelBatchJobAction() {
super(NAME, MLCancelBatchJobResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.task;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import lombok.Builder;
import lombok.Getter;

public class MLCancelBatchJobRequest extends ActionRequest {
@Getter
String taskId;

@Builder
public MLCancelBatchJobRequest(String taskId) {
this.taskId = taskId;
}

public MLCancelBatchJobRequest(StreamInput in) throws IOException {
super(in);
this.taskId = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.taskId);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.taskId == null) {
exception = addValidationError("ML task id can't be null", exception);
}

return exception;
}

public static MLCancelBatchJobRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLCancelBatchJobRequest) {
return (MLCancelBatchJobRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLCancelBatchJobRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLCancelBatchJobRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.task;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import lombok.Builder;
import lombok.Getter;

@Getter
public class MLCancelBatchJobResponse extends ActionResponse implements ToXContentObject {

RestStatus status;

@Builder
public MLCancelBatchJobResponse(RestStatus status) {
this.status = status;
}

public MLCancelBatchJobResponse(StreamInput in) throws IOException {
super(in);
status = in.readEnum(RestStatus.class);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(status);
}

public static MLCancelBatchJobResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLCancelBatchJobResponse) {
return (MLCancelBatchJobResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLCancelBatchJobResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLTaskGetResponse", e);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
return xContentBuilder.startObject().field("status", status).endObject();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package org.opensearch.ml.common.transport.task;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;

import java.io.IOException;
import java.io.UncheckedIOException;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamOutput;

public class MLCancelBatchJobRequestTest {
private String taskId;

@Before
public void setUp() {
taskId = "test_id";
}

@Test
public void writeTo_Success() throws IOException {
MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId(taskId).build();
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
mlCancelBatchJobRequest.writeTo(bytesStreamOutput);
MLCancelBatchJobRequest parsedTask = new MLCancelBatchJobRequest(bytesStreamOutput.bytes().streamInput());
assertEquals(parsedTask.getTaskId(), taskId);
}

@Test
public void validate_Exception_NullTaskId() {
MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().build();

ActionRequestValidationException exception = mlCancelBatchJobRequest.validate();
assertEquals("Validation Failed: 1: ML task id can't be null;", exception.getMessage());
}

@Test
public void fromActionRequest_Success() {
MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId(taskId).build();
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
mlCancelBatchJobRequest.writeTo(out);
}
};
MLCancelBatchJobRequest result = MLCancelBatchJobRequest.fromActionRequest(actionRequest);
assertNotSame(result, mlCancelBatchJobRequest);
assertEquals(result.getTaskId(), mlCancelBatchJobRequest.getTaskId());
}

@Test(expected = UncheckedIOException.class)
public void fromActionRequest_IOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException("test");
}
};
MLCancelBatchJobRequest.fromActionRequest(actionRequest);
}
}
Loading

0 comments on commit 224f8fc

Please sign in to comment.