From fb127f2d8895a9f4278b752cf356332b5a0fdffd Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Tue, 19 Mar 2024 01:50:13 -0500 Subject: [PATCH] addressed comments Signed-off-by: Bhavana Ramaram --- .../org/opensearch/ml/common/MLModel.java | 2 - .../model/QuestionAnsweringModelConfig.java | 52 +------------------ .../register/MLRegisterModelInput.java | 4 -- .../MLRegisterModelMetaInput.java | 3 -- .../QuestionAnsweringModelConfigTests.java | 3 +- .../org/opensearch/ml/engine/ModelHelper.java | 19 ++----- 6 files changed, 9 insertions(+), 74 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index a114fe305c..479cd09a73 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -40,7 +40,6 @@ public class MLModel implements ToXContentObject { @Deprecated public static final String ALGORITHM_FIELD = "algorithm"; public static final String FUNCTION_NAME_FIELD = "function_name"; - public static final String MODEL_TASK_TYPE_FIELD = "model_task_type"; public static final String MODEL_NAME_FIELD = "name"; public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // We use int type for version in first release 1.3. In 2.4, we changed to @@ -503,7 +502,6 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws case USER: user = User.parse(parser); break; - case MODEL_TASK_TYPE_FIELD: case ALGORITHM_FIELD: case FUNCTION_NAME_FIELD: algorithm = FunctionName.from(parser.text().toUpperCase(Locale.ROOT)); diff --git a/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java index 6c94636691..7b01f847a2 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java @@ -31,24 +31,20 @@ public class QuestionAnsweringModelConfig extends MLModelConfig { it -> parse(it) ); public static final String FRAMEWORK_TYPE_FIELD = "framework_type"; - public static final String POOLING_MODE_FIELD = "pooling_mode"; public static final String NORMALIZE_RESULT_FIELD = "normalize_result"; public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length"; private final FrameworkType frameworkType; - private final PoolingMode poolingMode; private final boolean normalizeResult; private final Integer modelMaxLength; @Builder(toBuilder = true) - public QuestionAnsweringModelConfig(String modelType, FrameworkType frameworkType, String allConfig, - PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) { + public QuestionAnsweringModelConfig(String modelType, FrameworkType frameworkType, String allConfig, boolean normalizeResult, Integer modelMaxLength) { super(modelType, allConfig); if (frameworkType == null) { throw new IllegalArgumentException("framework type is null"); } this.frameworkType = frameworkType; - this.poolingMode = poolingMode; this.normalizeResult = normalizeResult; this.modelMaxLength = modelMaxLength; } @@ -57,7 +53,6 @@ public static QuestionAnsweringModelConfig parse(XContentParser parser) throws I String modelType = null; FrameworkType frameworkType = null; String allConfig = null; - PoolingMode poolingMode = null; boolean normalizeResult = false; Integer modelMaxLength = null; @@ -76,9 +71,6 @@ public static QuestionAnsweringModelConfig parse(XContentParser parser) throws I case ALL_CONFIG_FIELD: allConfig = parser.text(); break; - case POOLING_MODE_FIELD: - poolingMode = PoolingMode.from(parser.text().toUpperCase(Locale.ROOT)); - break; case NORMALIZE_RESULT_FIELD: normalizeResult = parser.booleanValue(); break; @@ -90,7 +82,7 @@ public static QuestionAnsweringModelConfig parse(XContentParser parser) throws I break; } } - return new QuestionAnsweringModelConfig(modelType, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength); + return new QuestionAnsweringModelConfig(modelType, frameworkType, allConfig, normalizeResult, modelMaxLength); } @Override @@ -101,11 +93,6 @@ public String getWriteableName() { public QuestionAnsweringModelConfig(StreamInput in) throws IOException{ super(in); frameworkType = in.readEnum(FrameworkType.class); - if (in.readBoolean()) { - poolingMode = in.readEnum(PoolingMode.class); - } else { - poolingMode = null; - } normalizeResult = in.readBoolean(); modelMaxLength = in.readOptionalInt(); } @@ -114,12 +101,6 @@ public QuestionAnsweringModelConfig(StreamInput in) throws IOException{ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeEnum(frameworkType); - if (poolingMode != null) { - out.writeBoolean(true); - out.writeEnum(poolingMode); - } else { - out.writeBoolean(false); - } out.writeBoolean(normalizeResult); out.writeOptionalInt(modelMaxLength); } @@ -139,41 +120,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelMaxLength != null) { builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength); } - if (poolingMode != null) { - builder.field(POOLING_MODE_FIELD, poolingMode); - } if (normalizeResult) { builder.field(NORMALIZE_RESULT_FIELD, normalizeResult); } builder.endObject(); return builder; } - - public enum PoolingMode { - MEAN("mean"), - MEAN_SQRT_LEN("mean_sqrt_len"), - MAX("max"), - WEIGHTED_MEAN("weightedmean"), - CLS("cls"), - LAST_TOKEN("lasttoken"); - - private String name; - - public String getName() { - return name; - } - PoolingMode(String name) { - this.name = name; - } - - public static PoolingMode from(String value) { - try { - return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT)); - } catch (Exception e) { - throw new IllegalArgumentException("Wrong pooling method"); - } - } - } public enum FrameworkType { HUGGINGFACE_TRANSFORMERS, SENTENCE_TRANSFORMERS, diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index 39763491b1..4b3b3cfb0f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -42,8 +42,6 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String FUNCTION_NAME_FIELD = "function_name"; - public static final String MODEL_TASK_TYPE_FIELD = "model_task_type"; - public static final String NAME_FIELD = "name"; public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; public static final String DESCRIPTION_FIELD = "description"; @@ -362,7 +360,6 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName String fieldName = parser.currentName(); parser.nextToken(); switch (fieldName) { - case MODEL_TASK_TYPE_FIELD: case FUNCTION_NAME_FIELD: functionName = FunctionName.from(parser.text().toUpperCase(Locale.ROOT)); break; @@ -466,7 +463,6 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo parser.nextToken(); switch (fieldName) { - case MODEL_TASK_TYPE_FIELD: case FUNCTION_NAME_FIELD: functionName = FunctionName.from(parser.text().toUpperCase(Locale.ROOT)); break; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index e1df0d8c96..e7ab3b7091 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -22,7 +22,6 @@ import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; -import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.QuestionAnsweringModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -37,7 +36,6 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable { public static final String FUNCTION_NAME_FIELD = "function_name"; - public static final String MODEL_TASK_TYPE_FIELD = "model_task_type"; public static final String MODEL_NAME_FIELD = "name"; // mandatory public static final String DESCRIPTION_FIELD = "description"; // optional public static final String IS_ENABLED_FIELD = "is_enabled"; // optional @@ -305,7 +303,6 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc case MODEL_NAME_FIELD: name = parser.text(); break; - case MODEL_TASK_TYPE_FIELD: case FUNCTION_NAME_FIELD: functionName = FunctionName.from(parser.text()); break; diff --git a/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java index a51fbb0e58..5136c187b7 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java @@ -34,6 +34,7 @@ public void setUp() { config = QuestionAnsweringModelConfig.builder() .modelType("testModelType") .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .normalizeResult(false) .frameworkType(QuestionAnsweringModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) .build(); function = parser -> { @@ -72,7 +73,7 @@ public void nullFields_FrameworkType() { @Test public void parse() throws IOException { - String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"normalize_result\":false,\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; TestHelper.testParseFromString(config, content, function); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index aa9b25542b..50c514599e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -92,6 +92,10 @@ public void downloadPrebuiltModelConfig( MLRegisterModelInput.MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder(); + String functionName = config.containsKey("function_name") + ? (String) config.get("function_name") + : (String) config.get("model_task_type"); + builder .modelName(modelName) .version(version) @@ -100,7 +104,7 @@ public void downloadPrebuiltModelConfig( .modelNodeIds(modelNodeIds) .isHidden(isHidden) .modelGroupId(modelGroupId) - .functionName(FunctionName.from((String) config.get("model_task_type"))); + .functionName(FunctionName.from((functionName))); config.entrySet().forEach(entry -> { switch (entry.getKey().toString()) { @@ -126,19 +130,6 @@ public void downloadPrebuiltModelConfig( QuestionAnsweringModelConfig.FrameworkType.from(configEntry.getValue().toString()) ); break; - case QuestionAnsweringModelConfig.POOLING_MODE_FIELD: - configBuilder - .poolingMode( - QuestionAnsweringModelConfig.PoolingMode - .from(configEntry.getValue().toString().toUpperCase(Locale.ROOT)) - ); - break; - case QuestionAnsweringModelConfig.NORMALIZE_RESULT_FIELD: - configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString())); - break; - case QuestionAnsweringModelConfig.MODEL_MAX_LENGTH_FIELD: - configBuilder.modelMaxLength(((Double) configEntry.getValue()).intValue()); - break; default: break; }