Skip to content

Commit

Permalink
after latest review
Browse files Browse the repository at this point in the history
Signed-off-by: br3no <[email protected]>
  • Loading branch information
br3no committed Feb 19, 2024
1 parent 2a83cce commit 247a1c7
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@

package org.opensearch.ml.common.input.nlp;

import java.util.Locale;
import java.util.Map;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -22,12 +18,11 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.opensearch.ml.common.utils.StringUtils;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

/**
* ML input class which supports a list of text docs.
* ML input class which supports a list fo text docs.
* This class can be used for TEXT_EMBEDDING model.
*/
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE})
Expand Down Expand Up @@ -129,14 +124,6 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws
case RESULT_FILTER_FIELD:
resultFilter = ModelResultFilter.parse(parser);
break;
case ML_PARAMETERS_FIELD:
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
if (!parameters.containsKey(AsymmetricTextEmbeddingParameters.EMBEDDING_CONTENT_TYPE_FIELD)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Only accepted parameter is `%s`, which can have the values `query` or `passage`.", AsymmetricTextEmbeddingParameters.EMBEDDING_CONTENT_TYPE_FIELD));
}
EmbeddingContentType embeddingContentType = EmbeddingContentType.valueOf(parameters.get(AsymmetricTextEmbeddingParameters.EMBEDDING_CONTENT_TYPE_FIELD).toUpperCase(Locale.ROOT));
this.parameters = new AsymmetricTextEmbeddingParameters(embeddingContentType);
break;
default:
parser.skipChildren();
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand Down Expand Up @@ -41,9 +39,9 @@ public class TextDocsMLInputTest {
@Before
public void setUp() throws Exception {
ModelResultFilter resultFilter = ModelResultFilter.builder().returnBytes(true).returnNumber(true)
.targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build();
.targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build();
MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2"))
.resultFilter(resultFilter).build();
.resultFilter(resultFilter).build();
input = new TextDocsMLInput(algorithm, inputDataset);
}

Expand All @@ -68,25 +66,13 @@ public void parseTextDocsMLInput_NewWay() throws IOException {
parseMLInput(jsonStr, 2);
}

@Test
public void serializationRoundTrip() throws IOException {
MLInput mlInput = TextDocsMLInput.builder().inputDataset(
TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2")).build()).algorithm(algorithm)
.parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY))
.build();

String mlInput_json = mlInput.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS).toString();

MLInput mlInput_2 = parseMLInputJson(mlInput_json);

assertEquals(mlInput.getParameters(), mlInput_2.getParameters());
assertEquals(mlInput.getAlgorithm(), mlInput_2.getAlgorithm());
assertEquals(mlInput.getInputDataset().getInputDataType(), mlInput_2.getInputDataset().getInputDataType());

}

private void parseMLInput(String jsonStr, int docSize) throws IOException {
TextDocsMLInput parsedInput = parseMLInputJson(jsonStr);
XContentParser parser = XContentType.JSON.xContent()
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();

MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name());
assertTrue(parsedInput instanceof TextDocsMLInput);
assertEquals(input.getFunctionName(), parsedInput.getFunctionName());
assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType());
Expand All @@ -102,14 +88,4 @@ private void parseMLInput(String jsonStr, int docSize) throws IOException {
assertTrue(inputDataset.getResultFilter().isReturnNumber());
}

private TextDocsMLInput parseMLInputJson(String jsonStr) throws IOException {
XContentParser parser = XContentType.JSON.xContent()
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();

TextDocsMLInput parsedInput = (TextDocsMLInput) MLInput.parse(parser, input.getFunctionName().name());
return parsedInput;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public void nullFields_FrameworkType() {

@Test
public void parse() throws IOException {
String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}";
String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}";
TestHelper.testParseFromString(config, content, function);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,9 @@ private boolean isAsymmetricModel(MLAlgoParams mlParams) {
}
// Passed all checks
return true;
} else if (mlParams != null) {
// mlParams is not null and not an instance of AsymmetricTextEmbeddingParameters.
// Should never happen.
throw new IllegalArgumentException("TEXT_EMBEDDING algorithm only supports AsymmetricTextEmbeddingParameters.");
}

// mlParams is null, but the model is asymmetric.
// no AsymmetricTextEmbeddingParameters passed, but the model is asymmetric.
if (modelConfig != null
&& (((TextEmbeddingModelConfig) modelConfig).getPassagePrefix() != null
|| ((TextEmbeddingModelConfig) modelConfig).getQueryPrefix() != null)) {
Expand All @@ -80,11 +76,11 @@ private boolean isAsymmetricModel(MLAlgoParams mlParams) {
private TextDocsInputDataSet addPrefixesToData(AsymmetricTextEmbeddingParameters mlParams, TextDocsInputDataSet inputDataSet) {
// Asymmetric embedding models typically work with "mini-prompts" that prime the model to embed a text
// as a query or as a passage. Here we apply the prompt as defined in the model configuration. We default
// to passage embedding.
// to query embedding.
TextEmbeddingModelConfig modelConfig = (TextEmbeddingModelConfig) this.modelConfig;
String prefix = mlParams.getEmbeddingContentType() == EmbeddingContentType.QUERY
? modelConfig.getQueryPrefix()
: modelConfig.getPassagePrefix();
String prefix = mlParams.getEmbeddingContentType() == EmbeddingContentType.PASSAGE
? modelConfig.getPassagePrefix()
: modelConfig.getQueryPrefix();
if (prefix != null) {
List<String> prefixedDocs = inputDataSet.getDocs().stream().map(s -> prefix + s).collect(Collectors.toList());
return TextDocsInputDataSet.builder().docs(prefixedDocs).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams.DistanceType;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
Expand Down Expand Up @@ -496,41 +494,6 @@ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_As

}

@Test
public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts_SadPath3() throws URISyntaxException {
// wrong parameter type
Map<String, Object> params = new HashMap<>();
params.put(MODEL_HELPER, modelHelper);
params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI()));
params.put(ML_ENGINE, mlEngine);

TextEmbeddingModelConfig symmetricModelConfig = this.modelConfig.toBuilder().embeddingDimension(768).build();
MLModel symmetricSmallModel = model.toBuilder().modelConfig(symmetricModelConfig).build();
textEmbeddingDenseModel.initModel(symmetricSmallModel, params, encryptor);

MLInput asymmetricMlInputQueries = MLInput
.builder()
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(
TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build()
)
.parameters(new KMeansParams(0, 1, DistanceType.COSINE))
.build();

try {
textEmbeddingDenseModel.predict(asymmetricMlInputQueries);
} catch (MLException e) {
assertEquals(IllegalArgumentException.class, e.getCause().getClass());
assertEquals("TEXT_EMBEDDING algorithm only supports AsymmetricTextEmbeddingParameters.", e.getCause().getMessage());
return;
} finally {
textEmbeddingDenseModel.close();
}

fail("Expected exception not thrown");

}

@Test
public void initModel_NullModelZipFile() {
exceptionRule.expect(IllegalArgumentException.class);
Expand Down

0 comments on commit 247a1c7

Please sign in to comment.