Skip to content

Commit

Permalink
support model_task_type and qa_model_config in ml input
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Mar 19, 2024
1 parent 01c1649 commit 4de3643
Show file tree
Hide file tree
Showing 6 changed files with 412 additions and 43 deletions.
7 changes: 7 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.QuestionAnsweringModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;

Expand All @@ -38,6 +39,7 @@ public class MLModel implements ToXContentObject {
@Deprecated
public static final String ALGORITHM_FIELD = "algorithm";
public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_TASK_TYPE_FIELD = "model_task_type";
public static final String MODEL_NAME_FIELD = "name";
public static final String MODEL_GROUP_ID_FIELD = "model_group_id";
// We use int type for version in first release 1.3. In 2.4, we changed to
Expand Down Expand Up @@ -215,6 +217,8 @@ public MLModel(StreamInput input) throws IOException {
if (input.readBoolean()) {
if (algorithm.equals(FunctionName.METRICS_CORRELATION)) {
modelConfig = new MetricsCorrelationModelConfig(input);
} else if (algorithm.equals(FunctionName.QUESTION_ANSWERING)) {
modelConfig = new QuestionAnsweringModelConfig(input);

Check warning on line 221 in common/src/main/java/org/opensearch/ml/common/MLModel.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/MLModel.java#L221

Added line #L221 was not covered by tests
} else {
modelConfig = new TextEmbeddingModelConfig(input);
}
Expand Down Expand Up @@ -482,6 +486,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case USER:
user = User.parse(parser);
break;
case MODEL_TASK_TYPE_FIELD:
case ALGORITHM_FIELD:
case FUNCTION_NAME_FIELD:
algorithm = FunctionName.from(parser.text().toUpperCase(Locale.ROOT));
Expand Down Expand Up @@ -510,6 +515,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case MODEL_CONFIG_FIELD:
if (FunctionName.METRICS_CORRELATION.name().equals(algorithmName)) {
modelConfig = MetricsCorrelationModelConfig.parse(parser);
} else if (FunctionName.QUESTION_ANSWERING.name().equals(algorithmName)) {
modelConfig = QuestionAnsweringModelConfig.parse(parser);

Check warning on line 519 in common/src/main/java/org/opensearch/ml/common/MLModel.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/MLModel.java#L519

Added line #L519 was not covered by tests
} else {
modelConfig = TextEmbeddingModelConfig.parse(parser);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.model;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;

import java.io.IOException;
import java.util.Locale;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@Setter
@Getter
public class QuestionAnsweringModelConfig extends MLModelConfig {
public static final String PARSE_FIELD_NAME = FunctionName.QUESTION_ANSWERING.name();
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
QuestionAnsweringModelConfig.class,
new ParseField(PARSE_FIELD_NAME),
it -> parse(it)

Check warning on line 31 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L31

Added line #L31 was not covered by tests
);
public static final String FRAMEWORK_TYPE_FIELD = "framework_type";
public static final String POOLING_MODE_FIELD = "pooling_mode";
public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";

private final FrameworkType frameworkType;
private final PoolingMode poolingMode;
private final boolean normalizeResult;
private final Integer modelMaxLength;

@Builder(toBuilder = true)
public QuestionAnsweringModelConfig(String modelType, FrameworkType frameworkType, String allConfig,
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
super(modelType, allConfig);
if (frameworkType == null) {
throw new IllegalArgumentException("framework type is null");
}
this.frameworkType = frameworkType;
this.poolingMode = poolingMode;
this.normalizeResult = normalizeResult;
this.modelMaxLength = modelMaxLength;
}

public static QuestionAnsweringModelConfig parse(XContentParser parser) throws IOException {
String modelType = null;
FrameworkType frameworkType = null;
String allConfig = null;
PoolingMode poolingMode = null;
boolean normalizeResult = false;
Integer modelMaxLength = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case MODEL_TYPE_FIELD:
modelType = parser.text();
break;
case FRAMEWORK_TYPE_FIELD:
frameworkType = FrameworkType.from(parser.text().toUpperCase(Locale.ROOT));
break;
case ALL_CONFIG_FIELD:
allConfig = parser.text();
break;
case POOLING_MODE_FIELD:
poolingMode = PoolingMode.from(parser.text().toUpperCase(Locale.ROOT));
break;

Check warning on line 81 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L80-L81

Added lines #L80 - L81 were not covered by tests
case NORMALIZE_RESULT_FIELD:
normalizeResult = parser.booleanValue();
break;

Check warning on line 84 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L83-L84

Added lines #L83 - L84 were not covered by tests
case MODEL_MAX_LENGTH_FIELD:
modelMaxLength = parser.intValue();
break;

Check warning on line 87 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L86-L87

Added lines #L86 - L87 were not covered by tests
default:
parser.skipChildren();
break;
}
}
return new QuestionAnsweringModelConfig(modelType, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength);
}

@Override
public String getWriteableName() {
return PARSE_FIELD_NAME;
}

public QuestionAnsweringModelConfig(StreamInput in) throws IOException{
super(in);
frameworkType = in.readEnum(FrameworkType.class);
if (in.readBoolean()) {
poolingMode = in.readEnum(PoolingMode.class);

Check warning on line 105 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L105

Added line #L105 was not covered by tests
} else {
poolingMode = null;
}
normalizeResult = in.readBoolean();
modelMaxLength = in.readOptionalInt();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeEnum(frameworkType);
if (poolingMode != null) {
out.writeBoolean(true);
out.writeEnum(poolingMode);

Check warning on line 119 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L118-L119

Added lines #L118 - L119 were not covered by tests
} else {
out.writeBoolean(false);
}
out.writeBoolean(normalizeResult);
out.writeOptionalInt(modelMaxLength);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (modelType != null) {
builder.field(MODEL_TYPE_FIELD, modelType);
}
if (frameworkType != null) {
builder.field(FRAMEWORK_TYPE_FIELD, frameworkType);
}
if (allConfig != null) {
builder.field(ALL_CONFIG_FIELD, allConfig);
}
if (modelMaxLength != null) {
builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength);

Check warning on line 140 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L140

Added line #L140 was not covered by tests
}
if (poolingMode != null) {
builder.field(POOLING_MODE_FIELD, poolingMode);

Check warning on line 143 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L143

Added line #L143 was not covered by tests
}
if (normalizeResult) {
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);

Check warning on line 146 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L146

Added line #L146 was not covered by tests
}
builder.endObject();
return builder;
}

public enum PoolingMode {
MEAN("mean"),
MEAN_SQRT_LEN("mean_sqrt_len"),
MAX("max"),
WEIGHTED_MEAN("weightedmean"),
CLS("cls"),
LAST_TOKEN("lasttoken");

Check warning on line 158 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L152-L158

Added lines #L152 - L158 were not covered by tests

private String name;

public String getName() {
return name;

Check warning on line 163 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L163

Added line #L163 was not covered by tests
}
PoolingMode(String name) {
this.name = name;
}

Check warning on line 167 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L165-L167

Added lines #L165 - L167 were not covered by tests

public static PoolingMode from(String value) {
try {
return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong pooling method");

Check warning on line 173 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L171-L173

Added lines #L171 - L173 were not covered by tests
}
}
}
public enum FrameworkType {
HUGGINGFACE_TRANSFORMERS,
SENTENCE_TRANSFORMERS,
HUGGINGFACE_TRANSFORMERS_NEURON;

public static FrameworkType from(String value) {
try {
return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong framework type");
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.model.QuestionAnsweringModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

import java.io.IOException;
Expand All @@ -40,6 +41,8 @@
public class MLRegisterModelInput implements ToXContentObject, Writeable {

public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_TASK_TYPE_FIELD = "model_task_type";

public static final String NAME_FIELD = "name";
public static final String MODEL_GROUP_ID_FIELD = "model_group_id";
public static final String DESCRIPTION_FIELD = "description";
Expand Down Expand Up @@ -160,6 +163,8 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
if (in.readBoolean()) {
if (this.functionName.equals(FunctionName.METRICS_CORRELATION)) {
this.modelConfig = new MetricsCorrelationModelConfig(in);
} else if (this.functionName.equals(FunctionName.QUESTION_ANSWERING)) {
this.modelConfig = new QuestionAnsweringModelConfig(in);

Check warning on line 167 in common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java#L167

Added line #L167 was not covered by tests
} else {
this.modelConfig = new TextEmbeddingModelConfig(in);
}
Expand Down Expand Up @@ -334,6 +339,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
String fieldName = parser.currentName();
parser.nextToken();
switch (fieldName) {
case MODEL_TASK_TYPE_FIELD:
case FUNCTION_NAME_FIELD:
functionName = FunctionName.from(parser.text().toUpperCase(Locale.ROOT));
break;
Expand All @@ -359,7 +365,11 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
modelFormat = MLModelFormat.from(parser.text().toUpperCase(Locale.ROOT));
break;
case MODEL_CONFIG_FIELD:
modelConfig = TextEmbeddingModelConfig.parse(parser);
if (FunctionName.QUESTION_ANSWERING.equals(functionName)) {
modelConfig = QuestionAnsweringModelConfig.parse(parser);

Check warning on line 369 in common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java#L369

Added line #L369 was not covered by tests
} else {
modelConfig = TextEmbeddingModelConfig.parse(parser);
}
break;
case CONNECTOR_FIELD:
connector = createConnector(parser);
Expand Down Expand Up @@ -429,6 +439,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
parser.nextToken();

switch (fieldName) {
case MODEL_TASK_TYPE_FIELD:
case FUNCTION_NAME_FIELD:
functionName = FunctionName.from(parser.text().toUpperCase(Locale.ROOT));
break;
Expand Down Expand Up @@ -466,7 +477,11 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
modelFormat = MLModelFormat.from(parser.text().toUpperCase(Locale.ROOT));
break;
case MODEL_CONFIG_FIELD:
modelConfig = TextEmbeddingModelConfig.parse(parser);
if (FunctionName.QUESTION_ANSWERING.equals(functionName)) {
modelConfig = QuestionAnsweringModelConfig.parse(parser);

Check warning on line 481 in common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java#L481

Added line #L481 was not covered by tests
} else {
modelConfig = TextEmbeddingModelConfig.parse(parser);
}
break;
case MODEL_NODE_IDS_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.model.QuestionAnsweringModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

import java.io.IOException;
Expand All @@ -35,6 +37,7 @@
public class MLRegisterModelMetaInput implements ToXContentObject, Writeable {

public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_TASK_TYPE_FIELD = "model_task_type";
public static final String MODEL_NAME_FIELD = "name"; // mandatory
public static final String DESCRIPTION_FIELD = "description"; // optional
public static final String IS_ENABLED_FIELD = "is_enabled"; // optional
Expand Down Expand Up @@ -144,7 +147,11 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException {
this.modelContentSizeInBytes = in.readOptionalLong();
this.modelContentHashValue = in.readString();
if (in.readBoolean()) {
modelConfig = new TextEmbeddingModelConfig(in);
if (this.functionName.equals(FunctionName.QUESTION_ANSWERING)) {
this.modelConfig = new QuestionAnsweringModelConfig(in);

Check warning on line 151 in common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java#L151

Added line #L151 was not covered by tests
} else {
this.modelConfig = new TextEmbeddingModelConfig(in);
}
}
this.totalChunks = in.readInt();
this.backendRoles = in.readOptionalStringList();
Expand Down Expand Up @@ -298,6 +305,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
case MODEL_NAME_FIELD:
name = parser.text();
break;
case MODEL_TASK_TYPE_FIELD:
case FUNCTION_NAME_FIELD:
functionName = FunctionName.from(parser.text());
break;
Expand Down Expand Up @@ -329,7 +337,11 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
modelContentHashValue = parser.text();
break;
case MODEL_CONFIG_FIELD:
modelConfig = TextEmbeddingModelConfig.parse(parser);
if (FunctionName.QUESTION_ANSWERING.equals(functionName)) {
modelConfig = QuestionAnsweringModelConfig.parse(parser);

Check warning on line 341 in common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java#L341

Added line #L341 was not covered by tests
} else {
modelConfig = TextEmbeddingModelConfig.parse(parser);
}
break;
case TOTAL_CHUNKS_FIELD:
totalChunks = parser.intValue(false);
Expand Down
Loading

0 comments on commit 4de3643

Please sign in to comment.