From cd9cc9ee9d7feffe8a9083be9853b8951ab903d7 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Thu, 25 Apr 2024 06:20:46 +0200 Subject: [PATCH] Fixes #2317 predict api not working with asymmetric models (#2318) * Fixes #2317 predict api not working with asymmetric models Signed-off-by: br3no * Adding unit test code path for the parsing of the parameter. Signed-off-by: br3no * Removing involuntary import of guava Signed-off-by: br3no * Refactor package of AsymmetricTextEmbeddingParameters The MLCommonsClassLoader expects all MLAlgoParameters to be in the "org.opensearch.ml.common.input.parameter" package. Signed-off-by: br3no * fixing unit test after package refactoring Signed-off-by: br3no --------- Signed-off-by: br3no (cherry picked from commit 8425a658958cfbf95f9f116d18d8d8f57f6d8699) --- .../ml/common/input/nlp/TextDocsMLInput.java | 9 +++++++++ .../AsymmetricTextEmbeddingParameters.java | 4 ++-- .../AsymmetricTextEmbeddingParametersTest.java | 5 +++-- .../ml/common/input/nlp/TextDocsMLInputTest.java | 14 ++++++++++++-- .../ml/engine/algorithms/TextEmbeddingModel.java | 4 ++-- .../TextEmbeddingDenseModelTest.java | 4 ++-- .../ml/plugin/MachineLearningPlugin.java | 2 +- 7 files changed, 31 insertions(+), 11 deletions(-) rename common/src/main/java/org/opensearch/ml/common/{dataset => input/parameter/textembedding}/AsymmetricTextEmbeddingParameters.java (96%) diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java index deeb5ef81f..3fac064e0c 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java @@ -5,6 +5,7 @@ package org.opensearch.ml.common.input.nlp; +import java.util.Locale; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; @@ -13,6 +14,7 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.model.ModelResultFilter; import java.io.IOException; @@ -82,6 +84,7 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws List docs = new ArrayList<>(); ModelResultFilter resultFilter = null; + MLAlgoParams mlParameters = null; boolean returnBytes = false; boolean returnNumber = true; List targetResponse = new ArrayList<>(); @@ -93,6 +96,10 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws parser.nextToken(); switch (fieldName) { + case ML_PARAMETERS_FIELD: + mlParameters = parser.namedObject(MLAlgoParams.class, this.algorithm.name().toUpperCase( + Locale.ROOT), null); + break; case RETURN_BYTES_FIELD: returnBytes = parser.booleanValue(); break; @@ -137,6 +144,8 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws throw new IllegalArgumentException("Empty text docs"); } inputDataset = new TextDocsInputDataSet(docs, filter); + + this.parameters = mlParameters; } } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java similarity index 96% rename from common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java rename to common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java index 0c03d0b3be..be7c139efa 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.dataset; +package org.opensearch.ml.common.input.parameter.textembedding; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -33,7 +33,7 @@ * `query_prefix` and `passage_prefix` configuration parameters. */ @Data -@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING }) +@MLAlgoParameter(algorithms={FunctionName.TEXT_EMBEDDING}) public class AsymmetricTextEmbeddingParameters implements MLAlgoParams { public enum EmbeddingContentType { diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java index a7a27c00ee..df50348ff5 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java @@ -11,7 +11,8 @@ import java.io.IOException; import java.util.function.Function; -import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import static org.junit.Assert.assertEquals; import static org.opensearch.ml.common.TestHelper.contentObjectToString; @@ -52,7 +53,7 @@ public void parse_AsymmetricTextEmbeddingParameters_Passage() throws IOException @Test public void parse_AsymmetricTextEmbeddingParameters_Invalid() throws IOException { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU"); + exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU"); String paramsStr = contentObjectToString(params); testParseFromString(params, paramsStr.replace("QUERY","fu"), function); } diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java index fc052ea787..397769146c 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java @@ -1,5 +1,7 @@ package org.opensearch.ml.common.input.nlp; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -16,6 +18,7 @@ import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelResultFilter; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.search.SearchModule; import java.io.IOException; @@ -65,10 +68,17 @@ public void parseTextDocsMLInput_NewWay() throws IOException { parseMLInput(jsonStr, 2); } + @Test + public void parseTextDocsMLInput_WithParameters() throws IOException { + String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}, \"parameters\" : {\"content_type\": \"passage\"}}"; + parseMLInput(jsonStr, 2); + } + private void parseMLInput(String jsonStr, int docSize) throws IOException { XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + .createParser(new NamedXContentRegistry(Stream.concat(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents().stream(), Stream.of( + AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY)).collect(Collectors.toList())), null, jsonStr); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java index a5a3a8fb50..33a69697d1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java @@ -6,12 +6,12 @@ import java.util.Map; import java.util.stream.Collectors; -import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters; -import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.model.ModelResultFilter; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java index 75bdde5bbe..6c72a97fcb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java @@ -33,11 +33,11 @@ import org.opensearch.ResourceNotFoundException; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters; -import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 09b419f6a4..f6ffbe7853 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -91,7 +91,6 @@ import org.opensearch.ml.cluster.MLCommonsClusterEventListener; import org.opensearch.ml.cluster.MLCommonsClusterManagerEventListener; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.execute.anomalylocalization.AnomalyLocalizationInput; import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; @@ -103,6 +102,7 @@ import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams; import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.spi.MLCommonsExtension; import org.opensearch.ml.common.spi.memory.Memory;