diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaAction.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaAction.java deleted file mode 100644 index 6b9eba971c..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaAction.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import org.opensearch.action.ActionType; - -public class MLCreateModelMetaAction extends ActionType { - public static MLCreateModelMetaAction INSTANCE = new MLCreateModelMetaAction(); - public static final String NAME = "cluster:admin/opensearch/ml/create_model_meta"; - - private MLCreateModelMetaAction() { - super(NAME, MLCreateModelMetaResponse::new); - } - -} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java new file mode 100644 index 0000000000..3ee8b66805 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import org.opensearch.action.ActionType; + +public class MLRegisterModelMetaAction extends ActionType { + public static MLRegisterModelMetaAction INSTANCE = new MLRegisterModelMetaAction(); + public static final String NAME = "cluster:admin/opensearch/ml/register_model_meta"; + + private MLRegisterModelMetaAction() { + super(NAME, MLRegisterModelMetaResponse::new); + } + +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java similarity index 91% rename from common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaInput.java rename to common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index b9d171af5a..482704ed14 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -24,7 +24,7 @@ import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; @Data -public class MLCreateModelMetaInput implements ToXContentObject, Writeable{ +public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ public static final String FUNCTION_NAME_FIELD = "function_name"; public static final String MODEL_NAME_FIELD = "name"; //mandatory @@ -52,7 +52,7 @@ public class MLCreateModelMetaInput implements ToXContentObject, Writeable{ private Integer totalChunks; @Builder(toBuilder = true) - public MLCreateModelMetaInput(String name, FunctionName functionName, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks) { + public MLRegisterModelMetaInput(String name, FunctionName functionName, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks) { if (name == null) { throw new IllegalArgumentException("model name is null"); } @@ -87,7 +87,7 @@ public MLCreateModelMetaInput(String name, FunctionName functionName, String ver this.totalChunks = totalChunks; } - public MLCreateModelMetaInput(StreamInput in) throws IOException{ + public MLRegisterModelMetaInput(StreamInput in) throws IOException{ this.name = in.readString(); this.functionName = in.readEnum(FunctionName.class); this.version = in.readString(); @@ -158,7 +158,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static MLCreateModelMetaInput parse(XContentParser parser) throws IOException { + public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOException { String name = null; FunctionName functionName = null; String version = null; @@ -210,7 +210,7 @@ public static MLCreateModelMetaInput parse(XContentParser parser) throws IOExcep break; } } - return new MLCreateModelMetaInput(name, functionName, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks); + return new MLRegisterModelMetaInput(name, functionName, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java similarity index 67% rename from common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaRequest.java rename to common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java index 14a4c69aaf..c07b2ec53c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java @@ -27,24 +27,24 @@ @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLCreateModelMetaRequest extends ActionRequest { +public class MLRegisterModelMetaRequest extends ActionRequest { - MLCreateModelMetaInput mlCreateModelMetaInput; + MLRegisterModelMetaInput mlRegisterModelMetaInput; @Builder - public MLCreateModelMetaRequest(MLCreateModelMetaInput mlCreateModelMetaInput) { - this.mlCreateModelMetaInput = mlCreateModelMetaInput; + public MLRegisterModelMetaRequest(MLRegisterModelMetaInput mlRegisterModelMetaInput) { + this.mlRegisterModelMetaInput = mlRegisterModelMetaInput; } - public MLCreateModelMetaRequest(StreamInput in) throws IOException { + public MLRegisterModelMetaRequest(StreamInput in) throws IOException { super(in); - this.mlCreateModelMetaInput = new MLCreateModelMetaInput(in); + this.mlRegisterModelMetaInput = new MLRegisterModelMetaInput(in); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException exception = null; - if (mlCreateModelMetaInput == null) { + if (mlRegisterModelMetaInput == null) { exception = addValidationError("Model meta input can't be null", exception); } @@ -54,22 +54,22 @@ public ActionRequestValidationException validate() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - this.mlCreateModelMetaInput.writeTo(out); + this.mlRegisterModelMetaInput.writeTo(out); } - public static MLCreateModelMetaRequest fromActionRequest(ActionRequest actionRequest) { - if (actionRequest instanceof MLCreateModelMetaRequest) { - return (MLCreateModelMetaRequest) actionRequest; + public static MLRegisterModelMetaRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLRegisterModelMetaRequest) { + return (MLRegisterModelMetaRequest) actionRequest; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new MLCreateModelMetaRequest(input); + return new MLRegisterModelMetaRequest(input); } } catch (IOException e) { - throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateModelMetaRequest", e); + throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterModelMetaRequest", e); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java similarity index 81% rename from common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaResponse.java rename to common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java index af81d2a860..393aa1754c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java @@ -5,6 +5,7 @@ package org.opensearch.ml.common.transport.upload_chunk; +import lombok.Getter; import org.opensearch.action.ActionResponse; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -13,21 +14,23 @@ import java.io.IOException; -public class MLCreateModelMetaResponse extends ActionResponse implements ToXContentObject { +public class MLRegisterModelMetaResponse extends ActionResponse implements ToXContentObject { public static final String MODEL_ID_FIELD = "model_id"; public static final String STATUS_FIELD = "status"; + @Getter private String modelId; + @Getter private String status; - public MLCreateModelMetaResponse(StreamInput in) throws IOException { + public MLRegisterModelMetaResponse(StreamInput in) throws IOException { super(in); this.modelId = in.readString(); this.status = in.readString(); } - public MLCreateModelMetaResponse(String modelId, String status) { + public MLRegisterModelMetaResponse(String modelId, String status) { this.modelId = modelId; this.status= status; } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java index c4ca1febfd..23d83fbb5b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java @@ -5,6 +5,7 @@ package org.opensearch.ml.common.transport.upload_chunk; +import lombok.Getter; import org.opensearch.action.ActionResponse; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -15,6 +16,7 @@ public class MLUploadModelChunkResponse extends ActionResponse implements ToXContentObject { public static final String STATUS_FIELD = "status"; + @Getter private String status; public MLUploadModelChunkResponse (StreamInput in) throws IOException { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaRequestTest.java deleted file mode 100644 index 40b52d4451..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaRequestTest.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.model.MLModelFormat; -import org.opensearch.ml.common.model.MLModelState; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; - -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; - -public class MLCreateModelMetaRequestTest { - - TextEmbeddingModelConfig config; - MLCreateModelMetaInput mlCreateModelMetaInput; - - @Before - public void setUp() { - config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", - TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); - mlCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2); - } - - @Test - public void writeTo_Succeess() throws IOException { - MLCreateModelMetaRequest request = new MLCreateModelMetaRequest(mlCreateModelMetaInput); - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - request.writeTo(bytesStreamOutput); - MLCreateModelMetaRequest newRequest = new MLCreateModelMetaRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(request.getMlCreateModelMetaInput().getName(), newRequest.getMlCreateModelMetaInput().getName()); - assertEquals(request.getMlCreateModelMetaInput().getDescription(), - newRequest.getMlCreateModelMetaInput().getDescription()); - assertEquals(request.getMlCreateModelMetaInput().getFunctionName(), - newRequest.getMlCreateModelMetaInput().getFunctionName()); - assertEquals(request.getMlCreateModelMetaInput().getModelConfig().getAllConfig(), - newRequest.getMlCreateModelMetaInput().getModelConfig().getAllConfig()); - assertEquals(request.getMlCreateModelMetaInput().getVersion(), - newRequest.getMlCreateModelMetaInput().getVersion()); - } - - @Test - public void validate_Exception_NullModelId() { - MLCreateModelMetaRequest mlCreateModelMetaRequest = MLCreateModelMetaRequest.builder().build(); - ActionRequestValidationException exception = mlCreateModelMetaRequest.validate(); - assertEquals("Validation Failed: 1: Model meta input can't be null;", exception.getMessage()); - } - - @Test - public void fromActionRequest_Success() { - MLCreateModelMetaRequest request = new MLCreateModelMetaRequest(mlCreateModelMetaInput); - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - request.writeTo(out); - } - }; - MLCreateModelMetaRequest newRequest = MLCreateModelMetaRequest.fromActionRequest(actionRequest); - assertNotSame(request, newRequest); - assertEquals(request.getMlCreateModelMetaInput().getName(), newRequest.getMlCreateModelMetaInput().getName()); - assertEquals(request.getMlCreateModelMetaInput().getDescription(), - newRequest.getMlCreateModelMetaInput().getDescription()); - assertEquals(request.getMlCreateModelMetaInput().getFunctionName(), - newRequest.getMlCreateModelMetaInput().getFunctionName()); - assertEquals(request.getMlCreateModelMetaInput().getModelConfig().getAllConfig(), - newRequest.getMlCreateModelMetaInput().getModelConfig().getAllConfig()); - assertEquals(request.getMlCreateModelMetaInput().getVersion(), - newRequest.getMlCreateModelMetaInput().getVersion()); - } - - @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); - } - }; - MLCreateModelMetaRequest.fromActionRequest(actionRequest); - } -} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java similarity index 76% rename from common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaInputTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index 8f50306cf6..1d6f60b89f 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -25,43 +25,43 @@ import static org.junit.Assert.assertEquals; import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS; -public class MLCreateModelMetaInputTest { +public class MLRegisterModelMetaInputTest { - Function function = parser -> { + Function function = parser -> { try { - return MLCreateModelMetaInput.parse(parser); + return MLRegisterModelMetaInput.parse(parser); } catch (Exception e) { - throw new RuntimeException("Failed to parse MLCreateModelMetaInput", e); + throw new RuntimeException("Failed to parse MLRegisterModelMetaInput", e); } }; TextEmbeddingModelConfig config; - MLCreateModelMetaInput mLCreateModelMetaInput; + MLRegisterModelMetaInput mLRegisterModelMetaInput; @Before public void setup() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); - mLCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0", + mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0", "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2); } @Test - public void parse_MLCreateModelMetaInput() throws IOException { - TestHelper.testParse(mLCreateModelMetaInput, function); + public void parse_MLRegisterModelMetaInput() throws IOException { + TestHelper.testParse(mLRegisterModelMetaInput, function); } @Test public void readInputStream_Success() throws IOException { - readInputStream(mLCreateModelMetaInput); + readInputStream(mLRegisterModelMetaInput); } - private void readInputStream(MLCreateModelMetaInput input) throws IOException { + private void readInputStream(MLRegisterModelMetaInput input) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); - MLCreateModelMetaInput newInput = new MLCreateModelMetaInput(streamInput); + MLRegisterModelMetaInput newInput = new MLRegisterModelMetaInput(streamInput); assertEquals(input.getName(), newInput.getName()); assertEquals(input.getDescription(), newInput.getDescription()); assertEquals(input.getModelFormat(), newInput.getModelFormat()); @@ -73,7 +73,7 @@ private void readInputStream(MLCreateModelMetaInput input) throws IOException { @Test public void testToXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mLCreateModelMetaInput.toXContent(builder, EMPTY_PARAMS); + mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}"; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java new file mode 100644 index 0000000000..5a24a4dbd7 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +public class MLRegisterModelMetaRequestTest { + + TextEmbeddingModelConfig config; + MLRegisterModelMetaInput mlRegisterModelMetaInput; + + @Before + public void setUp() { + config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", + TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); + mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0", + "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2); + } + + @Test + public void writeTo_Succeess() throws IOException { + MLRegisterModelMetaRequest request = new MLRegisterModelMetaRequest(mlRegisterModelMetaInput); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLRegisterModelMetaRequest newRequest = new MLRegisterModelMetaRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(request.getMlRegisterModelMetaInput().getName(), newRequest.getMlRegisterModelMetaInput().getName()); + assertEquals(request.getMlRegisterModelMetaInput().getDescription(), + newRequest.getMlRegisterModelMetaInput().getDescription()); + assertEquals(request.getMlRegisterModelMetaInput().getFunctionName(), + newRequest.getMlRegisterModelMetaInput().getFunctionName()); + assertEquals(request.getMlRegisterModelMetaInput().getModelConfig().getAllConfig(), + newRequest.getMlRegisterModelMetaInput().getModelConfig().getAllConfig()); + assertEquals(request.getMlRegisterModelMetaInput().getVersion(), + newRequest.getMlRegisterModelMetaInput().getVersion()); + } + + @Test + public void validate_Exception_NullModelId() { + MLRegisterModelMetaRequest mlRegisterModelMetaRequest = MLRegisterModelMetaRequest.builder().build(); + ActionRequestValidationException exception = mlRegisterModelMetaRequest.validate(); + assertEquals("Validation Failed: 1: Model meta input can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + MLRegisterModelMetaRequest request = new MLRegisterModelMetaRequest(mlRegisterModelMetaInput); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLRegisterModelMetaRequest newRequest = MLRegisterModelMetaRequest.fromActionRequest(actionRequest); + assertNotSame(request, newRequest); + assertEquals(request.getMlRegisterModelMetaInput().getName(), newRequest.getMlRegisterModelMetaInput().getName()); + assertEquals(request.getMlRegisterModelMetaInput().getDescription(), + newRequest.getMlRegisterModelMetaInput().getDescription()); + assertEquals(request.getMlRegisterModelMetaInput().getFunctionName(), + newRequest.getMlRegisterModelMetaInput().getFunctionName()); + assertEquals(request.getMlRegisterModelMetaInput().getModelConfig().getAllConfig(), + newRequest.getMlRegisterModelMetaInput().getModelConfig().getAllConfig()); + assertEquals(request.getMlRegisterModelMetaInput().getVersion(), + newRequest.getMlRegisterModelMetaInput().getVersion()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLRegisterModelMetaRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java similarity index 62% rename from common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaResponseTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java index 77da7a5870..134fd1792c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java @@ -18,27 +18,28 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.ml.common.TestHelper; -public class MLCreateModelMetaResponseTest { +public class MLRegisterModelMetaResponseTest { - MLCreateModelMetaResponse mlCreateModelMetaResponse; + MLRegisterModelMetaResponse mlRegisterModelMetaResponse; @Before public void setup() { - mlCreateModelMetaResponse = new MLCreateModelMetaResponse("Model Id", "Status"); + mlRegisterModelMetaResponse = new MLRegisterModelMetaResponse("Model Id", "Status"); } @Test public void writeTo_Success() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - mlCreateModelMetaResponse.writeTo(bytesStreamOutput); - MLCreateModelMetaResponse newResponse = new MLCreateModelMetaResponse(bytesStreamOutput.bytes().streamInput()); -// assertEquals(mlCreateModelMetaResponse, newResponse); + mlRegisterModelMetaResponse.writeTo(bytesStreamOutput); + MLRegisterModelMetaResponse newResponse = new MLRegisterModelMetaResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(mlRegisterModelMetaResponse.getModelId(), newResponse.getModelId()); + assertEquals(mlRegisterModelMetaResponse.getStatus(), newResponse.getStatus()); } @Test public void testToXContent() throws IOException { - MLCreateModelMetaResponse response = new MLCreateModelMetaResponse("Model Id", "Status"); + MLRegisterModelMetaResponse response = new MLRegisterModelMetaResponse("Model Id", "Status"); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, EMPTY_PARAMS); assertNotNull(builder); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java index 2c630f7051..fcf00f2d67 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java @@ -27,13 +27,12 @@ public void setup() { mlUploadModelChunkResponse = new MLUploadModelChunkResponse("Status"); } - @Test public void writeTo_Success() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlUploadModelChunkResponse.writeTo(bytesStreamOutput); MLUploadModelChunkResponse newResponse = new MLUploadModelChunkResponse(bytesStreamOutput.bytes().streamInput()); -// assertEquals(response.getStatus(), newResponse.getStatus()); + assertEquals(mlUploadModelChunkResponse.getStatus(), newResponse.getStatus()); } @Test diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelMetaCreate.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelMetaCreate.java deleted file mode 100644 index 39288ce756..0000000000 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelMetaCreate.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.action.upload_chunk; - -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; - -import java.time.Instant; - -import lombok.extern.log4j.Log4j2; - -import org.opensearch.action.ActionListener; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.client.Client; -import org.opensearch.common.inject.Inject; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.ToXContent; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.model.MLModelState; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaInput; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.threadpool.ThreadPool; - -@Log4j2 -public class MLModelMetaCreate { - - private final MLIndicesHandler mlIndicesHandler; - private final ThreadPool threadPool; - private final Client client; - - @Inject - public MLModelMetaCreate(MLIndicesHandler mlIndicesHandler, ThreadPool threadPool, Client client) { - this.mlIndicesHandler = mlIndicesHandler; - this.threadPool = threadPool; - this.client = client; - } - - public void createModelMeta(MLCreateModelMetaInput mlCreateModelMetaInput, ActionListener listener) { - try { - String modelName = mlCreateModelMetaInput.getName(); - String version = mlCreateModelMetaInput.getVersion(); - FunctionName functionName = mlCreateModelMetaInput.getFunctionName(); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { - Instant now = Instant.now(); - MLModel mlModelMeta = MLModel - .builder() - .name(modelName) - .algorithm(functionName) - .version(version) - .description(mlCreateModelMetaInput.getDescription()) - .modelFormat(mlCreateModelMetaInput.getModelFormat()) - .modelState(MLModelState.REGISTERING) - .modelConfig(mlCreateModelMetaInput.getModelConfig()) - .totalChunks(mlCreateModelMetaInput.getTotalChunks()) - .modelContentHash(mlCreateModelMetaInput.getModelContentHashValue()) - .modelContentSizeInBytes(mlCreateModelMetaInput.getModelContentSizeInBytes()) - .createdTime(now) - .lastUpdateTime(now) - .build(); - IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); - indexRequest - .source(mlModelMeta.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, ActionListener.wrap(r -> { - log.debug("Index model meta doc successfully {}", modelName); - listener.onResponse(r.getId()); - }, e -> { - log.error("Failed to index model meta doc", e); - listener.onFailure(e); - })); - }, ex -> { - log.error("Failed to init model index", ex); - listener.onFailure(ex); - })); - } catch (Exception e) { - log.error("Failed to create model meta doc", e); - listener.onFailure(e); - } - } catch (final Exception e) { - log.error("Failed to init model index", e); - listener.onFailure(e); - } - } - -} diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportCreateModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportCreateModelMetaAction.java deleted file mode 100644 index d736e9fe84..0000000000 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportCreateModelMetaAction.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.action.upload_chunk; - -import lombok.extern.log4j.Log4j2; - -import org.opensearch.action.ActionListener; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.common.inject.Inject; -import org.opensearch.ml.common.MLTaskState; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaAction; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaInput; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaRequest; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaResponse; -import org.opensearch.tasks.Task; -import org.opensearch.transport.TransportService; - -@Log4j2 -public class TransportCreateModelMetaAction extends HandledTransportAction { - - TransportService transportService; - ActionFilters actionFilters; - MLModelMetaCreate mlModelMetaCreate; - - @Inject - public TransportCreateModelMetaAction( - TransportService transportService, - ActionFilters actionFilters, - MLModelMetaCreate mlModelMetaCreate - ) { - super(MLCreateModelMetaAction.NAME, transportService, actionFilters, MLCreateModelMetaRequest::new); - this.transportService = transportService; - this.actionFilters = actionFilters; - this.mlModelMetaCreate = mlModelMetaCreate; - } - - @Override - protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - MLCreateModelMetaRequest createModelMetaRequest = MLCreateModelMetaRequest.fromActionRequest(request); - MLCreateModelMetaInput mlUploadInput = createModelMetaRequest.getMlCreateModelMetaInput(); - mlModelMetaCreate - .createModelMeta( - mlUploadInput, - ActionListener - .wrap(modelId -> { listener.onResponse(new MLCreateModelMetaResponse(modelId, MLTaskState.CREATED.name())); }, ex -> { - log.error("Failed to init model index", ex); - listener.onFailure(ex); - }) - ); - } -} diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java new file mode 100644 index 0000000000..e9962f55c2 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.upload_chunk; + +import lombok.extern.log4j.Log4j2; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaAction; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaRequest; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaResponse; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +@Log4j2 +public class TransportRegisterModelMetaAction extends HandledTransportAction { + + TransportService transportService; + ActionFilters actionFilters; + MLModelManager mlModelManager; + + @Inject + public TransportRegisterModelMetaAction(TransportService transportService, ActionFilters actionFilters, MLModelManager mlModelManager) { + super(MLRegisterModelMetaAction.NAME, transportService, actionFilters, MLRegisterModelMetaRequest::new); + this.transportService = transportService; + this.actionFilters = actionFilters; + this.mlModelManager = mlModelManager; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + MLRegisterModelMetaRequest registerModelMetaRequest = MLRegisterModelMetaRequest.fromActionRequest(request); + MLRegisterModelMetaInput mlUploadInput = registerModelMetaRequest.getMlRegisterModelMetaInput(); + mlModelManager + .registerModelMeta( + mlUploadInput, + ActionListener + .wrap(modelId -> { listener.onResponse(new MLRegisterModelMetaResponse(modelId, MLTaskState.CREATED.name())); }, ex -> { + log.error("Failed to init model index", ex); + listener.onFailure(ex); + }) + ); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 24271de003..fa909bc607 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -69,8 +69,10 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.ToXContent; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; @@ -86,6 +88,7 @@ import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.Predictable; @@ -185,6 +188,55 @@ public MLModelManager( .addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it); } + public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener listener) { + try { + String modelName = mlRegisterModelMetaInput.getName(); + String version = mlRegisterModelMetaInput.getVersion(); + FunctionName functionName = mlRegisterModelMetaInput.getFunctionName(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { + Instant now = Instant.now(); + MLModel mlModelMeta = MLModel + .builder() + .name(modelName) + .algorithm(functionName) + .version(version) + .description(mlRegisterModelMetaInput.getDescription()) + .modelFormat(mlRegisterModelMetaInput.getModelFormat()) + .modelState(MLModelState.REGISTERING) + .modelConfig(mlRegisterModelMetaInput.getModelConfig()) + .totalChunks(mlRegisterModelMetaInput.getTotalChunks()) + .modelContentHash(mlRegisterModelMetaInput.getModelContentHashValue()) + .modelContentSizeInBytes(mlRegisterModelMetaInput.getModelContentSizeInBytes()) + .createdTime(now) + .lastUpdateTime(now) + .build(); + IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); + indexRequest + .source(mlModelMeta.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(indexRequest, ActionListener.wrap(r -> { + log.debug("Index model meta doc successfully {}", modelName); + listener.onResponse(r.getId()); + }, e -> { + log.error("Failed to index model meta doc", e); + listener.onFailure(e); + })); + }, ex -> { + log.error("Failed to init model index", ex); + listener.onFailure(ex); + })); + } catch (Exception e) { + log.error("Failed to register model meta doc", e); + listener.onFailure(e); + } + } catch (final Exception e) { + log.error("Failed to init model index", e); + listener.onFailure(e); + } + } + /** * Register model. Basically download model file, split into chunks and save into model index. * diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index f0cee79ef4..4ddf76a9c4 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -55,8 +55,7 @@ import org.opensearch.ml.action.trainpredict.TransportTrainAndPredictionTaskAction; import org.opensearch.ml.action.undeploy.TransportUndeployModelAction; import org.opensearch.ml.action.upload_chunk.MLModelChunkUploader; -import org.opensearch.ml.action.upload_chunk.MLModelMetaCreate; -import org.opensearch.ml.action.upload_chunk.TransportCreateModelMetaAction; +import org.opensearch.ml.action.upload_chunk.TransportRegisterModelMetaAction; import org.opensearch.ml.action.upload_chunk.TransportUploadModelChunkAction; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; @@ -90,7 +89,7 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaAction; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaAction; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkAction; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.MLEngineClassLoader; @@ -101,7 +100,6 @@ import org.opensearch.ml.indices.MLInputDatasetHandler; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; -import org.opensearch.ml.rest.RestMLCreateModelMetaAction; import org.opensearch.ml.rest.RestMLDeleteModelAction; import org.opensearch.ml.rest.RestMLDeleteTaskAction; import org.opensearch.ml.rest.RestMLDeployModelAction; @@ -111,6 +109,7 @@ import org.opensearch.ml.rest.RestMLPredictionAction; import org.opensearch.ml.rest.RestMLProfileAction; import org.opensearch.ml.rest.RestMLRegisterModelAction; +import org.opensearch.ml.rest.RestMLRegisterModelMetaAction; import org.opensearch.ml.rest.RestMLSearchModelAction; import org.opensearch.ml.rest.RestMLSearchTaskAction; import org.opensearch.ml.rest.RestMLStatsAction; @@ -171,7 +170,6 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin { private ModelHelper modelHelper; private DiscoveryNodeHelper nodeHelper; - private MLModelMetaCreate mlModelMetaCreate; private MLModelChunkUploader mlModelChunkUploader; private MLEngine mlEngine; @@ -203,7 +201,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin { new ActionHandler<>(MLDeployModelAction.INSTANCE, TransportDeployModelAction.class), new ActionHandler<>(MLDeployModelOnNodeAction.INSTANCE, TransportDeployModelOnNodeAction.class), new ActionHandler<>(MLUndeployModelAction.INSTANCE, TransportUndeployModelAction.class), - new ActionHandler<>(MLCreateModelMetaAction.INSTANCE, TransportCreateModelMetaAction.class), + new ActionHandler<>(MLRegisterModelMetaAction.INSTANCE, TransportRegisterModelMetaAction.class), new ActionHandler<>(MLUploadModelChunkAction.INSTANCE, TransportUploadModelChunkAction.class), new ActionHandler<>(MLForwardAction.INSTANCE, TransportForwardAction.class), new ActionHandler<>(MLSyncUpAction.INSTANCE, TransportSyncUpOnNodeAction.class) @@ -273,7 +271,6 @@ public Collection createComponents( ); mlInputDatasetHandler = new MLInputDatasetHandler(client); - mlModelMetaCreate = new MLModelMetaCreate(mlIndicesHandler, threadPool, client); mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry); MLTaskDispatcher mlTaskDispatcher = new MLTaskDispatcher(clusterService, client, settings, nodeHelper); @@ -369,7 +366,6 @@ public Collection createComponents( mlExecuteTaskRunner, mlSearchHandler, mlTaskDispatcher, - mlModelMetaCreate, mlModelChunkUploader, modelHelper, mlCommonsClusterEventListener, @@ -403,7 +399,7 @@ public List getRestHandlers( RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(); RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction(); RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService); - RestMLCreateModelMetaAction restMLCreateModelMetaAction = new RestMLCreateModelMetaAction(); + RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(); RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(); return ImmutableList @@ -423,7 +419,7 @@ public List getRestHandlers( restMLRegisterModelAction, restMLDeployModelAction, restMLUndeployModelAction, - restMLCreateModelMetaAction, + restMLRegisterModelMetaAction, restMLUploadModelChunkAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java similarity index 50% rename from plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateModelMetaAction.java rename to plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java index 7826b898d9..e4ba6fefd1 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java @@ -14,9 +14,9 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.common.xcontent.XContentParser; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaAction; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaInput; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaRequest; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaAction; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -24,28 +24,37 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -public class RestMLCreateModelMetaAction extends BaseRestHandler { - private static final String ML_CREATE_MODEL_META_ACTION = "ml_create_model_meta_action"; +public class RestMLRegisterModelMetaAction extends BaseRestHandler { + private static final String ML_REGISTER_MODEL_META_ACTION = "ml_register_model_meta_action"; /** * Constructor */ - public RestMLCreateModelMetaAction() {} + public RestMLRegisterModelMetaAction() {} @Override public String getName() { - return ML_CREATE_MODEL_META_ACTION; + return ML_REGISTER_MODEL_META_ACTION; } @Override - public List routes() { - return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/meta", ML_BASE_URI))); + public List replacedRoutes() { + return ImmutableList + .of( + new ReplacedRoute( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/models/_register_meta", ML_BASE_URI),// new url + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/models/meta", ML_BASE_URI)// old url + ) + ); } @Override public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - MLCreateModelMetaRequest mlCreateModelMetaRequest = getRequest(request); - return channel -> client.execute(MLCreateModelMetaAction.INSTANCE, mlCreateModelMetaRequest, new RestToXContentListener<>(channel)); + MLRegisterModelMetaRequest mlRegisterModelMetaRequest = getRequest(request); + return channel -> client + .execute(MLRegisterModelMetaAction.INSTANCE, mlRegisterModelMetaRequest, new RestToXContentListener<>(channel)); } /** @@ -55,14 +64,14 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client * @return MLUploadModelMetaRequest */ @VisibleForTesting - MLCreateModelMetaRequest getRequest(RestRequest request) throws IOException { + MLRegisterModelMetaRequest getRequest(RestRequest request) throws IOException { boolean hasContent = request.hasContent(); if (!hasContent) { throw new IOException("Model meta request has empty body"); } XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLCreateModelMetaInput mlInput = MLCreateModelMetaInput.parse(parser); - return new MLCreateModelMetaRequest(mlInput); + MLRegisterModelMetaInput mlInput = MLRegisterModelMetaInput.parse(parser); + return new MLRegisterModelMetaRequest(mlInput); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java index c03041021e..0e3441ada7 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java @@ -36,12 +36,14 @@ public String getName() { } @Override - public List routes() { + public List replacedRoutes() { return ImmutableList .of( - new Route( + new ReplacedRoute( RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/models/{%s}/chunk/{%s}", ML_BASE_URI, "model_id", "chunk_number") + String.format(Locale.ROOT, "%s/models/{%s}/upload_chunk/{%s}", ML_BASE_URI, "model_id", "chunk_number"),// new url + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/models/{%s}/chunk/{%s}", ML_BASE_URI, "model_id", "chunk_number")// old url ) ); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java index 7fbc1b4fa0..4688445c13 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java @@ -56,9 +56,6 @@ public class MLModelChunkUploaderTests extends OpenSearchTestCase { @Mock private Client client; - @Mock - private ActionListener getModelListener; - @Mock private ActionListener actionListener; @@ -73,9 +70,6 @@ public class MLModelChunkUploaderTests extends OpenSearchTestCase { @Mock private NamedXContentRegistry xContentRegistry; - @Mock - private MLModel mlModel; - @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); @@ -194,7 +188,7 @@ public void testUploadModelChunkSizeMorethan10MB() { assertEquals("Chunk size exceeds 10MB", argumentCaptor.getValue().getMessage()); } - public void testUploadModelChunkModelNotFound() throws IOException { + public void testUploadModelChunkModelNotFound() { MLModelChunkUploader mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry); MLUploadModelChunkInput uploadModelChunkInput = prepareRequest(); uploadModelChunkInput.setChunkNumber(5); @@ -209,7 +203,7 @@ public void testUploadModelChunkModelNotFound() throws IOException { assertEquals("Failed to find model", argumentCaptor.getValue().getMessage()); } - public void testUploadModelChunkModelIndexNotFound() throws IOException { + public void testUploadModelChunkModelIndexNotFound() { MLModelChunkUploader mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry); MLUploadModelChunkInput uploadModelChunkInput = prepareRequest(); uploadModelChunkInput.setChunkNumber(5); @@ -224,7 +218,7 @@ public void testUploadModelChunkModelIndexNotFound() throws IOException { assertEquals("Failed to find model", argumentCaptor.getValue().getMessage()); } - public void testUploadModelChunkIndexNotFound() throws IOException { + public void testUploadModelChunkIndexNotFound() { MLModelChunkUploader mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry); MLUploadModelChunkInput uploadModelChunkInput = prepareRequest(); uploadModelChunkInput.setChunkNumber(5); diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelMetaCreateTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelMetaCreateTests.java deleted file mode 100644 index aabf745bfd..0000000000 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelMetaCreateTests.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.action.upload_chunk; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.*; - -import java.util.concurrent.ExecutorService; - -import org.junit.Before; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.opensearch.action.ActionListener; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.client.Client; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.model.MLModelFormat; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaInput; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ThreadPool; - -public class MLModelMetaCreateTests extends OpenSearchTestCase { - - @Mock - private MLIndicesHandler mlIndicesHandler; - - @Mock - private ThreadPool threadPool; - - @Mock - private Client client; - - @Mock - private ActionListener actionListener; - - private ThreadContext threadContext; - - @Mock - private ExecutorService executorService; - - @Mock - private IndexResponse indexResponse; - - @Before - public void setup() { - MockitoAnnotations.openMocks(this); - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - when(threadPool.executor(anyString())).thenReturn(executorService); - doAnswer(invocation -> { - Runnable runnable = invocation.getArgument(0); - runnable.run(); - return null; - }).when(executorService).execute(any(Runnable.class)); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(client).index(any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(client).index(any(), any()); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(0); - actionListener.onResponse(true); - return null; - }).when(mlIndicesHandler).initModelIndexIfAbsent(any()); - - } - - public void testConstructor() { - MLModelMetaCreate mlModelChunkCreate = new MLModelMetaCreate(mlIndicesHandler, threadPool, client); - assertNotNull(mlModelChunkCreate); - } - - public void testCreateModelMeta() { - MLModelMetaCreate mlModelMetaCreate = new MLModelMetaCreate(mlIndicesHandler, threadPool, client); - MLCreateModelMetaInput createModelMetaInput = prepareRequest(); - mlModelMetaCreate.createModelMeta(createModelMetaInput, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class); - verify(actionListener).onResponse(argumentCaptor.capture()); - } - - public void testCreateModelMeta_FailedToInitIndex() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new Exception("Init Index Failed")); - return null; - }).when(client).index(any(), any()); - MLModelMetaCreate mlModelMetaCreate = new MLModelMetaCreate(mlIndicesHandler, threadPool, client); - MLCreateModelMetaInput createModelMetaInput = prepareRequest(); - mlModelMetaCreate.createModelMeta(createModelMetaInput, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - } - - public void testCreateModelMeta_FailedToInitIndexIfPresent() { - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(0); - actionListener.onFailure(new Exception("initModelIndexIfAbsent Failed")); - return null; - }).when(mlIndicesHandler).initModelIndexIfAbsent(any()); - MLModelMetaCreate mlModelMetaCreate = new MLModelMetaCreate(mlIndicesHandler, threadPool, client); - MLCreateModelMetaInput mlUploadInput = prepareRequest(); - mlModelMetaCreate.createModelMeta(mlUploadInput, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - } - - private MLCreateModelMetaInput prepareRequest() { - MLCreateModelMetaInput input = MLCreateModelMetaInput - .builder() - .name("Model Name") - .version("1") - .description("Custom Model Test") - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .functionName(FunctionName.BATCH_RCF) - .modelContentHashValue("14555") - .modelContentSizeInBytes(1000L) - .modelConfig( - new TextEmbeddingModelConfig( - "CUSTOM", - 123, - FrameworkType.SENTENCE_TRANSFORMERS, - "all config", - TextEmbeddingModelConfig.PoolingMode.MEAN, - true, - 512 - ) - ) - .totalChunks(2) - .build(); - return input; - } -} diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportCreateModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java similarity index 62% rename from plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportCreateModelMetaActionTests.java rename to plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java index 0e4bd61ac3..aba0f7f87d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportCreateModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java @@ -19,14 +19,15 @@ import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaInput; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaRequest; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaResponse; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaRequest; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaResponse; +import org.opensearch.ml.model.MLModelManager; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.TransportService; -public class TransportCreateModelMetaActionTests extends OpenSearchTestCase { +public class TransportRegisterModelMetaActionTests extends OpenSearchTestCase { @Mock private TransportService transportService; @@ -35,10 +36,10 @@ public class TransportCreateModelMetaActionTests extends OpenSearchTestCase { private ActionFilters actionFilters; @Mock - private MLModelMetaCreate mlModelMetaCreate; + private MLModelManager mlModelManager; @Mock - private ActionListener actionListener; + private ActionListener actionListener; @Mock private Task task; @@ -50,24 +51,24 @@ public void setup() { ActionListener listener = invocation.getArgument(1); listener.onResponse("customModelId"); return null; - }).when(mlModelMetaCreate).createModelMeta(any(), any()); + }).when(mlModelManager).registerModelMeta(any(), any()); } - public void testTransportUCreateModelMetaActionConstructor() { - TransportCreateModelMetaAction action = new TransportCreateModelMetaAction(transportService, actionFilters, mlModelMetaCreate); + public void testTransportRegisterModelMetaActionConstructor() { + TransportRegisterModelMetaAction action = new TransportRegisterModelMetaAction(transportService, actionFilters, mlModelManager); assertNotNull(action); } - public void testTransportCreateModelMetaActionDoExecute() { - TransportCreateModelMetaAction action = new TransportCreateModelMetaAction(transportService, actionFilters, mlModelMetaCreate); - MLCreateModelMetaRequest actionRequest = prepareRequest(); + public void testTransportRegisterModelMetaActionDoExecute() { + TransportRegisterModelMetaAction action = new TransportRegisterModelMetaAction(transportService, actionFilters, mlModelManager); + MLRegisterModelMetaRequest actionRequest = prepareRequest(); action.doExecute(task, actionRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLCreateModelMetaResponse.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelMetaResponse.class); verify(actionListener).onResponse(argumentCaptor.capture()); } - private MLCreateModelMetaRequest prepareRequest() { - MLCreateModelMetaInput input = MLCreateModelMetaInput + private MLRegisterModelMetaRequest prepareRequest() { + MLRegisterModelMetaInput input = MLRegisterModelMetaInput .builder() .name("Model Name") .version("1") @@ -89,7 +90,7 @@ private MLCreateModelMetaRequest prepareRequest() { ) .totalChunks(2) .build(); - return new MLCreateModelMetaRequest(input); + return new MLRegisterModelMetaRequest(input); } } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index d883313b2e..7c5c4dd225 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -88,6 +88,7 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.indices.MLIndicesHandler; @@ -151,6 +152,8 @@ public class MLModelManagerTests extends OpenSearchTestCase { ThresholdCircuitBreaker thresholdCircuitBreaker; @Mock DiscoveryNodeHelper nodeHelper; + @Mock + private ActionListener actionListener; @Before public void setup() throws URISyntaxException { @@ -763,6 +766,9 @@ private void setUpMock_DownloadModelFile(String[] chunks, Long modelContentSize) }).when(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any()); } + @Mock + private IndexResponse indexResponse; + private String[] createTempChunkFiles() throws IOException { String tmpFolder = randomAlphaOfLength(10); String newChunk0 = chunk0.substring(0, chunk0.length() - 2) + "/" + tmpFolder + "/0"; @@ -771,4 +777,84 @@ private String[] createTempChunkFiles() throws IOException { copyFile(chunk1, newChunk1); return new String[] { newChunk0, newChunk1 }; } + + public void testRegisterModelMeta() { + setupForModelMeta(); + MLRegisterModelMetaInput registerModelMetaInput = prepareRequest(); + modelManager.registerModelMeta(registerModelMetaInput, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + public void testRegisterModelMeta_FailedToInitIndex() { + setupForModelMeta(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("Init Index Failed")); + return null; + }).when(client).index(any(), any()); + MLRegisterModelMetaInput registerModelMetaInput = prepareRequest(); + modelManager.registerModelMeta(registerModelMetaInput, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + } + + public void testRegisterModelMeta_FailedToInitIndexIfPresent() { + setupForModelMeta(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(0); + actionListener.onFailure(new Exception("initModelIndexIfAbsent Failed")); + return null; + }).when(mlIndicesHandler).initModelIndexIfAbsent(any()); + MLRegisterModelMetaInput mlUploadInput = prepareRequest(); + modelManager.registerModelMeta(mlUploadInput, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + } + + private void setupForModelMeta() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initModelIndexIfAbsent(any()); + } + + private MLRegisterModelMetaInput prepareRequest() { + MLRegisterModelMetaInput input = MLRegisterModelMetaInput + .builder() + .name("Model Name") + .version("1") + .description("Custom Model Test") + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .functionName(FunctionName.BATCH_RCF) + .modelContentHashValue("14555") + .modelContentSizeInBytes(1000L) + .modelConfig( + new TextEmbeddingModelConfig( + "CUSTOM", + 123, + TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, + "all config", + TextEmbeddingModelConfig.PoolingMode.MEAN, + true, + 512 + ) + ) + .totalChunks(2) + .build(); + return input; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCustomModelChunkActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCustomModelChunkActionIT.java index 0bef5ab5df..e0e3bf3062 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCustomModelChunkActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCustomModelChunkActionIT.java @@ -18,7 +18,7 @@ import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaInput; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestStatus; @@ -33,14 +33,14 @@ public void setup() { } - protected Response createModelMeta() throws IOException { + protected Response registerModelMeta() throws IOException { Response uploadCustomModelMetaResponse = TestHelper .makeRequest(client(), "POST", "_plugins/_ml/models/meta", null, TestHelper.toHttpEntity(prepareModelMeta()), null); return uploadCustomModelMetaResponse; } - public void testCreateCustomMetaModel_Success() throws IOException { - Response customModelResponse = createModelMeta(); + public void testRegisterCustomMetaModel_Success() throws IOException { + Response customModelResponse = registerModelMeta(); assertNotNull(customModelResponse); HttpEntity entity = customModelResponse.getEntity(); assertNotNull(entity); @@ -54,8 +54,8 @@ public void testCreateCustomMetaModel_Success() throws IOException { assertEquals("CREATED", getModelMap.get("status")); } - public void testCreateCustomMetaModel_PredictException() throws IOException { - Response customModelResponse = createModelMeta(); + public void testRegisterCustomMetaModel_PredictException() throws IOException { + Response customModelResponse = registerModelMeta(); assertNotNull(customModelResponse); HttpEntity entity = customModelResponse.getEntity(); String entityString = TestHelper.httpEntityToString(entity); @@ -69,7 +69,7 @@ public void testCreateCustomMetaModel_PredictException() throws IOException { public void testCustomModelWorkflow() throws IOException, InterruptedException { // register chunk - Response customModelResponse = createModelMeta(); + Response customModelResponse = registerModelMeta(); assertNotNull(customModelResponse); HttpEntity entity = customModelResponse.getEntity(); String entityString = TestHelper.httpEntityToString(entity); @@ -102,19 +102,19 @@ public void testCustomModelWorkflow() throws IOException, InterruptedException { } protected Response uploadModelChunk(final String modelId, final int chunkNumber) throws IOException { - Response createChunkUploadResponse = TestHelper + Response uploadChunkResponse = TestHelper .makeRequest( client(), "POST", - "_plugins/_ml/models/" + modelId + "/chunk/" + chunkNumber, + "_plugins/_ml/models/" + modelId + "/upload_chunk/" + chunkNumber, null, TestHelper.toHttpEntity(prepareChunkUploadInput(modelId, chunkNumber)), null ); - assertNotNull(createChunkUploadResponse); - HttpEntity entity = createChunkUploadResponse.getEntity(); + assertNotNull(uploadChunkResponse); + HttpEntity entity = uploadChunkResponse.getEntity(); assertNotNull(entity); - return createChunkUploadResponse; + return uploadChunkResponse; } private String prepareModelMeta() throws IOException { @@ -125,7 +125,7 @@ private String prepareModelMeta() throws IOException { .frameworkType(FrameworkType.SENTENCE_TRANSFORMERS) .modelType("bert") .build(); - MLCreateModelMetaInput input = MLCreateModelMetaInput + MLRegisterModelMetaInput input = MLRegisterModelMetaInput .builder() .name("test_model") .version("1") diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java similarity index 67% rename from plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateModelMetaActionTests.java rename to plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java index 4381c3879b..b6590f8f47 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java @@ -15,6 +15,7 @@ import java.util.Map; import org.junit.Before; +import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; @@ -27,9 +28,9 @@ import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.XContentType; import org.opensearch.ml.common.transport.model.MLModelGetResponse; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaAction; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaInput; -import org.opensearch.ml.common.transport.upload_chunk.MLCreateModelMetaRequest; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaAction; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; +import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaRequest; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -40,9 +41,9 @@ import com.google.gson.Gson; -public class RestMLCreateModelMetaActionTests extends OpenSearchTestCase { +public class RestMLRegisterModelMetaActionTests extends OpenSearchTestCase { - private RestMLCreateModelMetaAction restMLCreateModelMetaAction; + private RestMLRegisterModelMetaAction restMLRegisterModelMetaAction; private NodeClient client; private ThreadPool threadPool; @@ -54,13 +55,13 @@ public class RestMLCreateModelMetaActionTests extends OpenSearchTestCase { @Before public void setup() { - restMLCreateModelMetaAction = new RestMLCreateModelMetaAction(); + restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); return null; - }).when(client).execute(eq(MLCreateModelMetaAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLRegisterModelMetaAction.INSTANCE), any(), any()); } @Override @@ -71,43 +72,53 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLCreateModelMetaAction mlUploadModel = new RestMLCreateModelMetaAction(); + RestMLRegisterModelMetaAction mlUploadModel = new RestMLRegisterModelMetaAction(); assertNotNull(mlUploadModel); } public void testGetName() { - String actionName = restMLCreateModelMetaAction.getName(); + String actionName = restMLRegisterModelMetaAction.getName(); assertFalse(Strings.isNullOrEmpty(actionName)); - assertEquals("ml_create_model_meta_action", actionName); + assertEquals("ml_register_model_meta_action", actionName); } + @Ignore public void testRoutes() { - List routes = restMLCreateModelMetaAction.routes(); + List routes = restMLRegisterModelMetaAction.routes(); assertNotNull(routes); assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.POST, route.getMethod()); - assertEquals("/_plugins/_ml/models/meta", route.getPath()); + assertEquals("/_plugins/_ml/models/_register_meta", route.getPath()); } - public void testCreateModelMetaRequest() throws Exception { + public void testReplacedRoutes() { + List replacedRoutes = restMLRegisterModelMetaAction.replacedRoutes(); + assertNotNull(replacedRoutes); + assertFalse(replacedRoutes.isEmpty()); + RestHandler.Route route = replacedRoutes.get(0); + assertEquals(RestRequest.Method.POST, route.getMethod()); + assertEquals("/_plugins/_ml/models/_register_meta", route.getPath()); + } + + public void testRegisterModelMetaRequest() throws Exception { RestRequest request = getRestRequest(); - restMLCreateModelMetaAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLCreateModelMetaRequest.class); - verify(client, times(1)).execute(eq(MLCreateModelMetaAction.INSTANCE), argumentCaptor.capture(), any()); - MLCreateModelMetaInput metaModelRequest = argumentCaptor.getValue().getMlCreateModelMetaInput(); + restMLRegisterModelMetaAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelMetaRequest.class); + verify(client, times(1)).execute(eq(MLRegisterModelMetaAction.INSTANCE), argumentCaptor.capture(), any()); + MLRegisterModelMetaInput metaModelRequest = argumentCaptor.getValue().getMlRegisterModelMetaInput(); assertEquals("all-MiniLM-L6-v3", metaModelRequest.getName()); assertEquals("1", metaModelRequest.getVersion()); assertEquals(Integer.valueOf(2), metaModelRequest.getTotalChunks()); } - public void testCreateModelMeta_NoContent() throws Exception { + public void testRegisterModelMeta_NoContent() throws Exception { RestRequest.Method method = RestRequest.Method.POST; Map params = new HashMap<>(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withMethod(method).withParams(params).build(); expectedEx.expect(IOException.class); expectedEx.expectMessage("Model meta request has empty body"); - restMLCreateModelMetaAction.handleRequest(request, channel, client); + restMLRegisterModelMetaAction.handleRequest(request, channel, client); } private RestRequest getRestRequest() { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java index 657ff1a04e..e52c4b22ff 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java @@ -14,6 +14,7 @@ import java.util.Map; import org.junit.Before; +import org.junit.Ignore; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.opensearch.action.ActionListener; @@ -72,13 +73,23 @@ public void testGetName() { assertEquals("ml_upload_model_chunk_action", actionName); } + @Ignore public void testRoutes() { List routes = restChunkUploadAction.routes(); assertNotNull(routes); assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.POST, route.getMethod()); - assertEquals("/_plugins/_ml/models/{model_id}/chunk/{chunk_number}", route.getPath()); + assertEquals("/_plugins/_ml/models/{model_id}/upload_chunk/{chunk_number}", route.getPath()); + } + + public void testReplacedRoutes() { + List replacedRoutes = restChunkUploadAction.replacedRoutes(); + assertNotNull(replacedRoutes); + assertFalse(replacedRoutes.isEmpty()); + RestHandler.Route route = replacedRoutes.get(0); + assertEquals(RestRequest.Method.POST, route.getMethod()); + assertEquals("/_plugins/_ml/models/{model_id}/upload_chunk/{chunk_number}", route.getPath()); } public void testUploadChunkRequest() throws Exception {