From 247a1c77928b2f25ab50db99ff7587b87f0e4433 Mon Sep 17 00:00:00 2001 From: br3no Date: Mon, 19 Feb 2024 23:14:55 +0100 Subject: [PATCH] after latest review Signed-off-by: br3no --- .../ml/common/input/nlp/TextDocsMLInput.java | 15 +------ .../common/input/nlp/TextDocsMLInputTest.java | 40 ++++--------------- .../model/TextEmbeddingModelConfigTests.java | 2 +- .../engine/algorithms/TextEmbeddingModel.java | 14 +++---- .../TextEmbeddingDenseModelTest.java | 37 ----------------- 5 files changed, 15 insertions(+), 93 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java index 304f837488..deeb5ef81f 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java @@ -5,15 +5,11 @@ package org.opensearch.ml.common.input.nlp; -import java.util.Locale; -import java.util.Map; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters; -import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -22,12 +18,11 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; -import org.opensearch.ml.common.utils.StringUtils; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; /** - * ML input class which supports a list of text docs. + * ML input class which supports a list fo text docs. * This class can be used for TEXT_EMBEDDING model. */ @org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE}) @@ -129,14 +124,6 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws case RESULT_FILTER_FIELD: resultFilter = ModelResultFilter.parse(parser); break; - case ML_PARAMETERS_FIELD: - Map parameters = StringUtils.getParameterMap(parser.map()); - if (!parameters.containsKey(AsymmetricTextEmbeddingParameters.EMBEDDING_CONTENT_TYPE_FIELD)) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "Only accepted parameter is `%s`, which can have the values `query` or `passage`.", AsymmetricTextEmbeddingParameters.EMBEDDING_CONTENT_TYPE_FIELD)); - } - EmbeddingContentType embeddingContentType = EmbeddingContentType.valueOf(parameters.get(AsymmetricTextEmbeddingParameters.EMBEDDING_CONTENT_TYPE_FIELD).toUpperCase(Locale.ROOT)); - this.parameters = new AsymmetricTextEmbeddingParameters(embeddingContentType); - break; default: parser.skipChildren(); break; diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java index f51ad5c0b3..8819786fbe 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java @@ -12,8 +12,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters; -import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -41,9 +39,9 @@ public class TextDocsMLInputTest { @Before public void setUp() throws Exception { ModelResultFilter resultFilter = ModelResultFilter.builder().returnBytes(true).returnNumber(true) - .targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build(); + .targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build(); MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2")) - .resultFilter(resultFilter).build(); + .resultFilter(resultFilter).build(); input = new TextDocsMLInput(algorithm, inputDataset); } @@ -68,25 +66,13 @@ public void parseTextDocsMLInput_NewWay() throws IOException { parseMLInput(jsonStr, 2); } - @Test - public void serializationRoundTrip() throws IOException { - MLInput mlInput = TextDocsMLInput.builder().inputDataset( - TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2")).build()).algorithm(algorithm) - .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY)) - .build(); - - String mlInput_json = mlInput.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS).toString(); - - MLInput mlInput_2 = parseMLInputJson(mlInput_json); - - assertEquals(mlInput.getParameters(), mlInput_2.getParameters()); - assertEquals(mlInput.getAlgorithm(), mlInput_2.getAlgorithm()); - assertEquals(mlInput.getInputDataset().getInputDataType(), mlInput_2.getInputDataset().getInputDataType()); - - } - private void parseMLInput(String jsonStr, int docSize) throws IOException { - TextDocsMLInput parsedInput = parseMLInputJson(jsonStr); + XContentParser parser = XContentType.JSON.xContent() + .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + + MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); assertTrue(parsedInput instanceof TextDocsMLInput); assertEquals(input.getFunctionName(), parsedInput.getFunctionName()); assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType()); @@ -102,14 +88,4 @@ private void parseMLInput(String jsonStr, int docSize) throws IOException { assertTrue(inputDataset.getResultFilter().isReturnNumber()); } - private TextDocsMLInput parseMLInputJson(String jsonStr) throws IOException { - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); - parser.nextToken(); - - TextDocsMLInput parsedInput = (TextDocsMLInput) MLInput.parse(parser, input.getFunctionName().name()); - return parsedInput; - } - } diff --git a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java index 2cc86d7500..9bc97f7c9f 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java @@ -85,7 +85,7 @@ public void nullFields_FrameworkType() { @Test public void parse() throws IOException { - String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}"; TestHelper.testParseFromString(config, content, function); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java index 89f1b96960..a5a3a8fb50 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java @@ -59,13 +59,9 @@ private boolean isAsymmetricModel(MLAlgoParams mlParams) { } // Passed all checks return true; - } else if (mlParams != null) { - // mlParams is not null and not an instance of AsymmetricTextEmbeddingParameters. - // Should never happen. - throw new IllegalArgumentException("TEXT_EMBEDDING algorithm only supports AsymmetricTextEmbeddingParameters."); } - // mlParams is null, but the model is asymmetric. + // no AsymmetricTextEmbeddingParameters passed, but the model is asymmetric. if (modelConfig != null && (((TextEmbeddingModelConfig) modelConfig).getPassagePrefix() != null || ((TextEmbeddingModelConfig) modelConfig).getQueryPrefix() != null)) { @@ -80,11 +76,11 @@ private boolean isAsymmetricModel(MLAlgoParams mlParams) { private TextDocsInputDataSet addPrefixesToData(AsymmetricTextEmbeddingParameters mlParams, TextDocsInputDataSet inputDataSet) { // Asymmetric embedding models typically work with "mini-prompts" that prime the model to embed a text // as a query or as a passage. Here we apply the prompt as defined in the model configuration. We default - // to passage embedding. + // to query embedding. TextEmbeddingModelConfig modelConfig = (TextEmbeddingModelConfig) this.modelConfig; - String prefix = mlParams.getEmbeddingContentType() == EmbeddingContentType.QUERY - ? modelConfig.getQueryPrefix() - : modelConfig.getPassagePrefix(); + String prefix = mlParams.getEmbeddingContentType() == EmbeddingContentType.PASSAGE + ? modelConfig.getPassagePrefix() + : modelConfig.getQueryPrefix(); if (prefix != null) { List prefixedDocs = inputDataSet.getDocs().stream().map(s -> prefix + s).collect(Collectors.toList()); return TextDocsInputDataSet.builder().docs(prefixedDocs).build(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java index 6e638c4f7e..75bdde5bbe 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java @@ -38,8 +38,6 @@ import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams.DistanceType; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -496,41 +494,6 @@ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_As } - @Test - public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts_SadPath3() throws URISyntaxException { - // wrong parameter type - Map params = new HashMap<>(); - params.put(MODEL_HELPER, modelHelper); - params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI())); - params.put(ML_ENGINE, mlEngine); - - TextEmbeddingModelConfig symmetricModelConfig = this.modelConfig.toBuilder().embeddingDimension(768).build(); - MLModel symmetricSmallModel = model.toBuilder().modelConfig(symmetricModelConfig).build(); - textEmbeddingDenseModel.initModel(symmetricSmallModel, params, encryptor); - - MLInput asymmetricMlInputQueries = MLInput - .builder() - .algorithm(FunctionName.TEXT_EMBEDDING) - .inputDataset( - TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build() - ) - .parameters(new KMeansParams(0, 1, DistanceType.COSINE)) - .build(); - - try { - textEmbeddingDenseModel.predict(asymmetricMlInputQueries); - } catch (MLException e) { - assertEquals(IllegalArgumentException.class, e.getCause().getClass()); - assertEquals("TEXT_EMBEDDING algorithm only supports AsymmetricTextEmbeddingParameters.", e.getCause().getMessage()); - return; - } finally { - textEmbeddingDenseModel.close(); - } - - fail("Expected exception not thrown"); - - } - @Test public void initModel_NullModelZipFile() { exceptionRule.expect(IllegalArgumentException.class);