Skip to content

Commit

Permalink
addressed comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Mar 19, 2024
1 parent 2943726 commit fb127f2
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 74 deletions.
2 changes: 0 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ public class MLModel implements ToXContentObject {
@Deprecated
public static final String ALGORITHM_FIELD = "algorithm";
public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_TASK_TYPE_FIELD = "model_task_type";
public static final String MODEL_NAME_FIELD = "name";
public static final String MODEL_GROUP_ID_FIELD = "model_group_id";
// We use int type for version in first release 1.3. In 2.4, we changed to
Expand Down Expand Up @@ -503,7 +502,6 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case USER:
user = User.parse(parser);
break;
case MODEL_TASK_TYPE_FIELD:
case ALGORITHM_FIELD:
case FUNCTION_NAME_FIELD:
algorithm = FunctionName.from(parser.text().toUpperCase(Locale.ROOT));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,20 @@ public class QuestionAnsweringModelConfig extends MLModelConfig {
it -> parse(it)
);
public static final String FRAMEWORK_TYPE_FIELD = "framework_type";
public static final String POOLING_MODE_FIELD = "pooling_mode";
public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";

private final FrameworkType frameworkType;
private final PoolingMode poolingMode;
private final boolean normalizeResult;
private final Integer modelMaxLength;

@Builder(toBuilder = true)
public QuestionAnsweringModelConfig(String modelType, FrameworkType frameworkType, String allConfig,
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
public QuestionAnsweringModelConfig(String modelType, FrameworkType frameworkType, String allConfig, boolean normalizeResult, Integer modelMaxLength) {
super(modelType, allConfig);
if (frameworkType == null) {
throw new IllegalArgumentException("framework type is null");
}
this.frameworkType = frameworkType;
this.poolingMode = poolingMode;
this.normalizeResult = normalizeResult;
this.modelMaxLength = modelMaxLength;
}
Expand All @@ -57,7 +53,6 @@ public static QuestionAnsweringModelConfig parse(XContentParser parser) throws I
String modelType = null;
FrameworkType frameworkType = null;
String allConfig = null;
PoolingMode poolingMode = null;
boolean normalizeResult = false;
Integer modelMaxLength = null;

Expand All @@ -76,9 +71,6 @@ public static QuestionAnsweringModelConfig parse(XContentParser parser) throws I
case ALL_CONFIG_FIELD:
allConfig = parser.text();
break;
case POOLING_MODE_FIELD:
poolingMode = PoolingMode.from(parser.text().toUpperCase(Locale.ROOT));
break;
case NORMALIZE_RESULT_FIELD:
normalizeResult = parser.booleanValue();
break;
Expand All @@ -90,7 +82,7 @@ public static QuestionAnsweringModelConfig parse(XContentParser parser) throws I
break;
}
}
return new QuestionAnsweringModelConfig(modelType, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength);
return new QuestionAnsweringModelConfig(modelType, frameworkType, allConfig, normalizeResult, modelMaxLength);
}

@Override
Expand All @@ -101,11 +93,6 @@ public String getWriteableName() {
public QuestionAnsweringModelConfig(StreamInput in) throws IOException{
super(in);
frameworkType = in.readEnum(FrameworkType.class);
if (in.readBoolean()) {
poolingMode = in.readEnum(PoolingMode.class);
} else {
poolingMode = null;
}
normalizeResult = in.readBoolean();
modelMaxLength = in.readOptionalInt();
}
Expand All @@ -114,12 +101,6 @@ public QuestionAnsweringModelConfig(StreamInput in) throws IOException{
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeEnum(frameworkType);
if (poolingMode != null) {
out.writeBoolean(true);
out.writeEnum(poolingMode);
} else {
out.writeBoolean(false);
}
out.writeBoolean(normalizeResult);
out.writeOptionalInt(modelMaxLength);
}
Expand All @@ -139,41 +120,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (modelMaxLength != null) {
builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength);
}
if (poolingMode != null) {
builder.field(POOLING_MODE_FIELD, poolingMode);
}
if (normalizeResult) {
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
}
builder.endObject();
return builder;
}

public enum PoolingMode {
MEAN("mean"),
MEAN_SQRT_LEN("mean_sqrt_len"),
MAX("max"),
WEIGHTED_MEAN("weightedmean"),
CLS("cls"),
LAST_TOKEN("lasttoken");

private String name;

public String getName() {
return name;
}
PoolingMode(String name) {
this.name = name;
}

public static PoolingMode from(String value) {
try {
return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong pooling method");
}
}
}
public enum FrameworkType {
HUGGINGFACE_TRANSFORMERS,
SENTENCE_TRANSFORMERS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
public class MLRegisterModelInput implements ToXContentObject, Writeable {

public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_TASK_TYPE_FIELD = "model_task_type";

public static final String NAME_FIELD = "name";
public static final String MODEL_GROUP_ID_FIELD = "model_group_id";
public static final String DESCRIPTION_FIELD = "description";
Expand Down Expand Up @@ -362,7 +360,6 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
String fieldName = parser.currentName();
parser.nextToken();
switch (fieldName) {
case MODEL_TASK_TYPE_FIELD:
case FUNCTION_NAME_FIELD:
functionName = FunctionName.from(parser.text().toUpperCase(Locale.ROOT));
break;
Expand Down Expand Up @@ -466,7 +463,6 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
parser.nextToken();

switch (fieldName) {
case MODEL_TASK_TYPE_FIELD:
case FUNCTION_NAME_FIELD:
functionName = FunctionName.from(parser.text().toUpperCase(Locale.ROOT));
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.model.QuestionAnsweringModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

Expand All @@ -37,7 +36,6 @@
public class MLRegisterModelMetaInput implements ToXContentObject, Writeable {

public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_TASK_TYPE_FIELD = "model_task_type";
public static final String MODEL_NAME_FIELD = "name"; // mandatory
public static final String DESCRIPTION_FIELD = "description"; // optional
public static final String IS_ENABLED_FIELD = "is_enabled"; // optional
Expand Down Expand Up @@ -305,7 +303,6 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
case MODEL_NAME_FIELD:
name = parser.text();
break;
case MODEL_TASK_TYPE_FIELD:
case FUNCTION_NAME_FIELD:
functionName = FunctionName.from(parser.text());
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public void setUp() {
config = QuestionAnsweringModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.normalizeResult(false)
.frameworkType(QuestionAnsweringModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
.build();
function = parser -> {
Expand Down Expand Up @@ -72,7 +73,7 @@ public void nullFields_FrameworkType() {

@Test
public void parse() throws IOException {
String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}";
String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"normalize_result\":false,\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}";
TestHelper.testParseFromString(config, content, function);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ public void downloadPrebuiltModelConfig(

MLRegisterModelInput.MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder();

String functionName = config.containsKey("function_name")
? (String) config.get("function_name")
: (String) config.get("model_task_type");

builder
.modelName(modelName)
.version(version)
Expand All @@ -100,7 +104,7 @@ public void downloadPrebuiltModelConfig(
.modelNodeIds(modelNodeIds)
.isHidden(isHidden)
.modelGroupId(modelGroupId)
.functionName(FunctionName.from((String) config.get("model_task_type")));
.functionName(FunctionName.from((functionName)));

config.entrySet().forEach(entry -> {
switch (entry.getKey().toString()) {
Expand All @@ -126,19 +130,6 @@ public void downloadPrebuiltModelConfig(
QuestionAnsweringModelConfig.FrameworkType.from(configEntry.getValue().toString())
);
break;
case QuestionAnsweringModelConfig.POOLING_MODE_FIELD:
configBuilder
.poolingMode(
QuestionAnsweringModelConfig.PoolingMode
.from(configEntry.getValue().toString().toUpperCase(Locale.ROOT))
);
break;
case QuestionAnsweringModelConfig.NORMALIZE_RESULT_FIELD:
configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString()));
break;
case QuestionAnsweringModelConfig.MODEL_MAX_LENGTH_FIELD:
configBuilder.modelMaxLength(((Double) configEntry.getValue()).intValue());
break;
default:
break;
}
Expand Down

0 comments on commit fb127f2

Please sign in to comment.