-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
Signed-off-by: Bhavana Ramaram <[email protected]>
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.model; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
import org.opensearch.core.ParseField; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.FunctionName; | ||
|
||
import java.io.IOException; | ||
import java.util.Locale; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
|
||
@Setter | ||
@Getter | ||
public class QuestionAnsweringModelConfig extends MLModelConfig { | ||
public static final String PARSE_FIELD_NAME = FunctionName.QUESTION_ANSWERING.name(); | ||
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( | ||
QuestionAnsweringModelConfig.class, | ||
new ParseField(PARSE_FIELD_NAME), | ||
it -> parse(it) | ||
Check warning on line 31 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L31
|
||
); | ||
public static final String FRAMEWORK_TYPE_FIELD = "framework_type"; | ||
public static final String POOLING_MODE_FIELD = "pooling_mode"; | ||
public static final String NORMALIZE_RESULT_FIELD = "normalize_result"; | ||
public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length"; | ||
|
||
private final FrameworkType frameworkType; | ||
private final PoolingMode poolingMode; | ||
private final boolean normalizeResult; | ||
private final Integer modelMaxLength; | ||
|
||
@Builder(toBuilder = true) | ||
public QuestionAnsweringModelConfig(String modelType, FrameworkType frameworkType, String allConfig, | ||
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) { | ||
super(modelType, allConfig); | ||
if (frameworkType == null) { | ||
throw new IllegalArgumentException("framework type is null"); | ||
} | ||
this.frameworkType = frameworkType; | ||
this.poolingMode = poolingMode; | ||
this.normalizeResult = normalizeResult; | ||
this.modelMaxLength = modelMaxLength; | ||
} | ||
|
||
public static QuestionAnsweringModelConfig parse(XContentParser parser) throws IOException { | ||
String modelType = null; | ||
FrameworkType frameworkType = null; | ||
String allConfig = null; | ||
PoolingMode poolingMode = null; | ||
boolean normalizeResult = false; | ||
Integer modelMaxLength = null; | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
|
||
switch (fieldName) { | ||
case MODEL_TYPE_FIELD: | ||
modelType = parser.text(); | ||
break; | ||
case FRAMEWORK_TYPE_FIELD: | ||
frameworkType = FrameworkType.from(parser.text().toUpperCase(Locale.ROOT)); | ||
break; | ||
case ALL_CONFIG_FIELD: | ||
allConfig = parser.text(); | ||
break; | ||
case POOLING_MODE_FIELD: | ||
poolingMode = PoolingMode.from(parser.text().toUpperCase(Locale.ROOT)); | ||
break; | ||
Check warning on line 81 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L80-L81
|
||
case NORMALIZE_RESULT_FIELD: | ||
normalizeResult = parser.booleanValue(); | ||
break; | ||
Check warning on line 84 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L83-L84
|
||
case MODEL_MAX_LENGTH_FIELD: | ||
modelMaxLength = parser.intValue(); | ||
break; | ||
Check warning on line 87 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L86-L87
|
||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
return new QuestionAnsweringModelConfig(modelType, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength); | ||
} | ||
|
||
@Override | ||
public String getWriteableName() { | ||
return PARSE_FIELD_NAME; | ||
} | ||
|
||
public QuestionAnsweringModelConfig(StreamInput in) throws IOException{ | ||
super(in); | ||
frameworkType = in.readEnum(FrameworkType.class); | ||
if (in.readBoolean()) { | ||
poolingMode = in.readEnum(PoolingMode.class); | ||
Check warning on line 105 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L105
|
||
} else { | ||
poolingMode = null; | ||
} | ||
normalizeResult = in.readBoolean(); | ||
modelMaxLength = in.readOptionalInt(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeEnum(frameworkType); | ||
if (poolingMode != null) { | ||
out.writeBoolean(true); | ||
out.writeEnum(poolingMode); | ||
Check warning on line 119 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L118-L119
|
||
} else { | ||
out.writeBoolean(false); | ||
} | ||
out.writeBoolean(normalizeResult); | ||
out.writeOptionalInt(modelMaxLength); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
if (modelType != null) { | ||
builder.field(MODEL_TYPE_FIELD, modelType); | ||
} | ||
if (frameworkType != null) { | ||
builder.field(FRAMEWORK_TYPE_FIELD, frameworkType); | ||
} | ||
if (allConfig != null) { | ||
builder.field(ALL_CONFIG_FIELD, allConfig); | ||
} | ||
if (modelMaxLength != null) { | ||
builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength); | ||
Check warning on line 140 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L140
|
||
} | ||
if (poolingMode != null) { | ||
builder.field(POOLING_MODE_FIELD, poolingMode); | ||
Check warning on line 143 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L143
|
||
} | ||
if (normalizeResult) { | ||
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult); | ||
Check warning on line 146 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L146
|
||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
public enum PoolingMode { | ||
MEAN("mean"), | ||
MEAN_SQRT_LEN("mean_sqrt_len"), | ||
MAX("max"), | ||
WEIGHTED_MEAN("weightedmean"), | ||
CLS("cls"), | ||
LAST_TOKEN("lasttoken"); | ||
Check warning on line 158 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L152-L158
|
||
|
||
private String name; | ||
|
||
public String getName() { | ||
return name; | ||
Check warning on line 163 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L163
|
||
} | ||
PoolingMode(String name) { | ||
this.name = name; | ||
} | ||
Check warning on line 167 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L165-L167
|
||
|
||
public static PoolingMode from(String value) { | ||
try { | ||
return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT)); | ||
} catch (Exception e) { | ||
throw new IllegalArgumentException("Wrong pooling method"); | ||
Check warning on line 173 in common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java#L171-L173
|
||
} | ||
} | ||
} | ||
public enum FrameworkType { | ||
HUGGINGFACE_TRANSFORMERS, | ||
SENTENCE_TRANSFORMERS, | ||
HUGGINGFACE_TRANSFORMERS_NEURON; | ||
|
||
public static FrameworkType from(String value) { | ||
try { | ||
return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT)); | ||
} catch (Exception e) { | ||
throw new IllegalArgumentException("Wrong framework type"); | ||
} | ||
} | ||
} | ||
|
||
} |