forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support get batch transform job status in get task API (opensearch-pr…
…oject#2825) (opensearch-project#2893) * support get batch transform job status in get task API Signed-off-by: Bhavana Ramaram <[email protected]> * add cancel batch prediction job API for offline inference Signed-off-by: Bhavana Ramaram <[email protected]> * add unit tests and address comments Signed-off-by: Bhavana Ramaram <[email protected]> * stash context for get model Signed-off-by: Bhavana Ramaram <[email protected]> * apply spotlessJava and exclude from test coverage Signed-off-by: Bhavana Ramaram <[email protected]> --------- Signed-off-by: Bhavana Ramaram <[email protected]> (cherry picked from commit 8da7bd2) Co-authored-by: Bhavana Ramaram <[email protected]>
- Loading branch information
1 parent
a91843c
commit 224f8fc
Showing
22 changed files
with
1,650 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
17 changes: 17 additions & 0 deletions
17
common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.task; | ||
|
||
import org.opensearch.action.ActionType; | ||
|
||
public class MLCancelBatchJobAction extends ActionType<MLCancelBatchJobResponse> { | ||
public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job"; | ||
|
||
private MLCancelBatchJobAction() { | ||
super(NAME, MLCancelBatchJobResponse::new); | ||
} | ||
} |
70 changes: 70 additions & 0 deletions
70
common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.task; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
|
||
public class MLCancelBatchJobRequest extends ActionRequest { | ||
@Getter | ||
String taskId; | ||
|
||
@Builder | ||
public MLCancelBatchJobRequest(String taskId) { | ||
this.taskId = taskId; | ||
} | ||
|
||
public MLCancelBatchJobRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.taskId = in.readString(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.taskId); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
|
||
if (this.taskId == null) { | ||
exception = addValidationError("ML task id can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
public static MLCancelBatchJobRequest fromActionRequest(ActionRequest actionRequest) { | ||
if (actionRequest instanceof MLCancelBatchJobRequest) { | ||
return (MLCancelBatchJobRequest) actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLCancelBatchJobRequest(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionRequest into MLCancelBatchJobRequest", e); | ||
} | ||
} | ||
} |
64 changes: 64 additions & 0 deletions
64
common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.task; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import org.opensearch.core.action.ActionResponse; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
|
||
@Getter | ||
public class MLCancelBatchJobResponse extends ActionResponse implements ToXContentObject { | ||
|
||
RestStatus status; | ||
|
||
@Builder | ||
public MLCancelBatchJobResponse(RestStatus status) { | ||
this.status = status; | ||
} | ||
|
||
public MLCancelBatchJobResponse(StreamInput in) throws IOException { | ||
super(in); | ||
status = in.readEnum(RestStatus.class); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeEnum(status); | ||
} | ||
|
||
public static MLCancelBatchJobResponse fromActionResponse(ActionResponse actionResponse) { | ||
if (actionResponse instanceof MLCancelBatchJobResponse) { | ||
return (MLCancelBatchJobResponse) actionResponse; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionResponse.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLCancelBatchJobResponse(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionResponse into MLTaskGetResponse", e); | ||
} | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { | ||
return xContentBuilder.startObject().field("status", status).endObject(); | ||
} | ||
} |
75 changes: 75 additions & 0 deletions
75
...on/src/test/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobRequestTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package org.opensearch.ml.common.transport.task; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertNotSame; | ||
|
||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.common.io.stream.BytesStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
|
||
public class MLCancelBatchJobRequestTest { | ||
private String taskId; | ||
|
||
@Before | ||
public void setUp() { | ||
taskId = "test_id"; | ||
} | ||
|
||
@Test | ||
public void writeTo_Success() throws IOException { | ||
MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId(taskId).build(); | ||
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); | ||
mlCancelBatchJobRequest.writeTo(bytesStreamOutput); | ||
MLCancelBatchJobRequest parsedTask = new MLCancelBatchJobRequest(bytesStreamOutput.bytes().streamInput()); | ||
assertEquals(parsedTask.getTaskId(), taskId); | ||
} | ||
|
||
@Test | ||
public void validate_Exception_NullTaskId() { | ||
MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().build(); | ||
|
||
ActionRequestValidationException exception = mlCancelBatchJobRequest.validate(); | ||
assertEquals("Validation Failed: 1: ML task id can't be null;", exception.getMessage()); | ||
} | ||
|
||
@Test | ||
public void fromActionRequest_Success() { | ||
MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId(taskId).build(); | ||
ActionRequest actionRequest = new ActionRequest() { | ||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
mlCancelBatchJobRequest.writeTo(out); | ||
} | ||
}; | ||
MLCancelBatchJobRequest result = MLCancelBatchJobRequest.fromActionRequest(actionRequest); | ||
assertNotSame(result, mlCancelBatchJobRequest); | ||
assertEquals(result.getTaskId(), mlCancelBatchJobRequest.getTaskId()); | ||
} | ||
|
||
@Test(expected = UncheckedIOException.class) | ||
public void fromActionRequest_IOException() { | ||
ActionRequest actionRequest = new ActionRequest() { | ||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
throw new IOException("test"); | ||
} | ||
}; | ||
MLCancelBatchJobRequest.fromActionRequest(actionRequest); | ||
} | ||
} |
Oops, something went wrong.