Skip to content

Commit

Permalink
rename model meta/chunk API (#827)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Mar 27, 2023
1 parent 158d445 commit bf662e3
Show file tree
Hide file tree
Showing 25 changed files with 470 additions and 551 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.upload_chunk;

import org.opensearch.action.ActionType;

public class MLRegisterModelMetaAction extends ActionType<MLRegisterModelMetaResponse> {
public static MLRegisterModelMetaAction INSTANCE = new MLRegisterModelMetaAction();
public static final String NAME = "cluster:admin/opensearch/ml/register_model_meta";

private MLRegisterModelMetaAction() {
super(NAME, MLRegisterModelMetaResponse::new);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

@Data
public class MLCreateModelMetaInput implements ToXContentObject, Writeable{
public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{

public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_NAME_FIELD = "name"; //mandatory
Expand Down Expand Up @@ -52,7 +52,7 @@ public class MLCreateModelMetaInput implements ToXContentObject, Writeable{
private Integer totalChunks;

@Builder(toBuilder = true)
public MLCreateModelMetaInput(String name, FunctionName functionName, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks) {
public MLRegisterModelMetaInput(String name, FunctionName functionName, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks) {
if (name == null) {
throw new IllegalArgumentException("model name is null");
}
Expand Down Expand Up @@ -87,7 +87,7 @@ public MLCreateModelMetaInput(String name, FunctionName functionName, String ver
this.totalChunks = totalChunks;
}

public MLCreateModelMetaInput(StreamInput in) throws IOException{
public MLRegisterModelMetaInput(StreamInput in) throws IOException{
this.name = in.readString();
this.functionName = in.readEnum(FunctionName.class);
this.version = in.readString();
Expand Down Expand Up @@ -158,7 +158,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

public static MLCreateModelMetaInput parse(XContentParser parser) throws IOException {
public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOException {
String name = null;
FunctionName functionName = null;
String version = null;
Expand Down Expand Up @@ -210,7 +210,7 @@ public static MLCreateModelMetaInput parse(XContentParser parser) throws IOExcep
break;
}
}
return new MLCreateModelMetaInput(name, functionName, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks);
return new MLRegisterModelMetaInput(name, functionName, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,24 @@
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLCreateModelMetaRequest extends ActionRequest {
public class MLRegisterModelMetaRequest extends ActionRequest {

MLCreateModelMetaInput mlCreateModelMetaInput;
MLRegisterModelMetaInput mlRegisterModelMetaInput;

@Builder
public MLCreateModelMetaRequest(MLCreateModelMetaInput mlCreateModelMetaInput) {
this.mlCreateModelMetaInput = mlCreateModelMetaInput;
public MLRegisterModelMetaRequest(MLRegisterModelMetaInput mlRegisterModelMetaInput) {
this.mlRegisterModelMetaInput = mlRegisterModelMetaInput;
}

public MLCreateModelMetaRequest(StreamInput in) throws IOException {
public MLRegisterModelMetaRequest(StreamInput in) throws IOException {
super(in);
this.mlCreateModelMetaInput = new MLCreateModelMetaInput(in);
this.mlRegisterModelMetaInput = new MLRegisterModelMetaInput(in);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (mlCreateModelMetaInput == null) {
if (mlRegisterModelMetaInput == null) {
exception = addValidationError("Model meta input can't be null", exception);
}

Expand All @@ -54,22 +54,22 @@ public ActionRequestValidationException validate() {
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
this.mlCreateModelMetaInput.writeTo(out);
this.mlRegisterModelMetaInput.writeTo(out);
}

public static MLCreateModelMetaRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLCreateModelMetaRequest) {
return (MLCreateModelMetaRequest) actionRequest;
public static MLRegisterModelMetaRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLRegisterModelMetaRequest) {
return (MLRegisterModelMetaRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLCreateModelMetaRequest(input);
return new MLRegisterModelMetaRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateModelMetaRequest", e);
throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterModelMetaRequest", e);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.common.transport.upload_chunk;

import lombok.Getter;
import org.opensearch.action.ActionResponse;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
Expand All @@ -13,21 +14,23 @@

import java.io.IOException;

public class MLCreateModelMetaResponse extends ActionResponse implements ToXContentObject {
public class MLRegisterModelMetaResponse extends ActionResponse implements ToXContentObject {

public static final String MODEL_ID_FIELD = "model_id";
public static final String STATUS_FIELD = "status";

@Getter
private String modelId;
@Getter
private String status;

public MLCreateModelMetaResponse(StreamInput in) throws IOException {
public MLRegisterModelMetaResponse(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.status = in.readString();
}

public MLCreateModelMetaResponse(String modelId, String status) {
public MLRegisterModelMetaResponse(String modelId, String status) {
this.modelId = modelId;
this.status= status;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.common.transport.upload_chunk;

import lombok.Getter;
import org.opensearch.action.ActionResponse;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
Expand All @@ -15,6 +16,7 @@

public class MLUploadModelChunkResponse extends ActionResponse implements ToXContentObject {
public static final String STATUS_FIELD = "status";
@Getter
private String status;

public MLUploadModelChunkResponse (StreamInput in) throws IOException {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -25,43 +25,43 @@
import static org.junit.Assert.assertEquals;
import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;

public class MLCreateModelMetaInputTest {
public class MLRegisterModelMetaInputTest {


Function<XContentParser, MLCreateModelMetaInput> function = parser -> {
Function<XContentParser, MLRegisterModelMetaInput> function = parser -> {
try {
return MLCreateModelMetaInput.parse(parser);
return MLRegisterModelMetaInput.parse(parser);
} catch (Exception e) {
throw new RuntimeException("Failed to parse MLCreateModelMetaInput", e);
throw new RuntimeException("Failed to parse MLRegisterModelMetaInput", e);
}
};
TextEmbeddingModelConfig config;
MLCreateModelMetaInput mLCreateModelMetaInput;
MLRegisterModelMetaInput mLRegisterModelMetaInput;

@Before
public void setup() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mLCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0",
mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2);
}

@Test
public void parse_MLCreateModelMetaInput() throws IOException {
TestHelper.testParse(mLCreateModelMetaInput, function);
public void parse_MLRegisterModelMetaInput() throws IOException {
TestHelper.testParse(mLRegisterModelMetaInput, function);
}

@Test
public void readInputStream_Success() throws IOException {
readInputStream(mLCreateModelMetaInput);
readInputStream(mLRegisterModelMetaInput);
}


private void readInputStream(MLCreateModelMetaInput input) throws IOException {
private void readInputStream(MLRegisterModelMetaInput input) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
input.writeTo(bytesStreamOutput);
StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
MLCreateModelMetaInput newInput = new MLCreateModelMetaInput(streamInput);
MLRegisterModelMetaInput newInput = new MLRegisterModelMetaInput(streamInput);
assertEquals(input.getName(), newInput.getName());
assertEquals(input.getDescription(), newInput.getDescription());
assertEquals(input.getModelFormat(), newInput.getModelFormat());
Expand All @@ -73,7 +73,7 @@ private void readInputStream(MLCreateModelMetaInput input) throws IOException {
@Test
public void testToXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLCreateModelMetaInput.toXContent(builder, EMPTY_PARAMS);
mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
Expand Down
Loading

0 comments on commit bf662e3

Please sign in to comment.