diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index 96c6a58235..72810459a4 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -5,6 +5,9 @@ package org.opensearch.ml.common; +import java.util.HashSet; +import java.util.Set; + public enum FunctionName { LINEAR_REGRESSION, KMEANS, @@ -17,6 +20,7 @@ public enum FunctionName { RCF_SUMMARIZE, LOGISTIC_REGRESSION, TEXT_EMBEDDING, + TEXT_SIMILARITY, SPARSE_ENCODING, SPARSE_TOKENIZE, METRICS_CORRELATION, @@ -30,14 +34,18 @@ public static FunctionName from(String value) { } } + private static final HashSet DL_MODELS = new HashSet<>(Set.of( + TEXT_EMBEDDING, + TEXT_SIMILARITY, + SPARSE_ENCODING, + SPARSE_TOKENIZE + )); + /** * Check if model is deep learning model. * @return true for deep learning model. */ public static boolean isDLModel(FunctionName functionName) { - if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == SPARSE_TOKENIZE) { - return true; - } - return false; + return DL_MODELS.contains(functionName); } } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java index 46cdb161bd..d115bba3e0 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java @@ -9,5 +9,6 @@ public enum MLInputDataType { SEARCH_QUERY, DATA_FRAME, TEXT_DOCS, + TEXT_SIMILARITY, REMOTE } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java new file mode 100644 index 0000000000..649146410c --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java @@ -0,0 +1,75 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.common.dataset; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.annotation.InputDataSet; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@InputDataSet(MLInputDataType.TEXT_SIMILARITY) +public class TextSimilarityInputDataSet extends MLInputDataset { + + List textDocs; + + String queryText; + + @Builder(toBuilder = true) + public TextSimilarityInputDataSet(String queryText, List textDocs) { + super(MLInputDataType.TEXT_SIMILARITY); + Objects.requireNonNull(textDocs); + Objects.requireNonNull(queryText); + if(textDocs.isEmpty()) { + throw new IllegalArgumentException("No text documents were provided"); + } + this.textDocs = textDocs; + this.queryText = queryText; + } + + public TextSimilarityInputDataSet(StreamInput in) throws IOException { + super(MLInputDataType.TEXT_SIMILARITY); + this.queryText = in.readString(); + int size = in.readInt(); + this.textDocs = new ArrayList(); + for(int i = 0; i < size; i++) { + String context = in.readString(); + this.textDocs.add(context); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(queryText); + out.writeInt(this.textDocs.size()); + for (String doc : this.textDocs) { + out.writeString(doc); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java index 574f13e9c3..acd1522736 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java @@ -8,6 +8,7 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; @@ -21,6 +22,7 @@ import org.opensearch.ml.common.dataset.SearchQueryInputDataset; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.search.builder.SearchSourceBuilder; @@ -55,6 +57,8 @@ public class MLInput implements Input { public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions"; // Input text sentences for text embedding model public static final String TEXT_DOCS_FIELD = "text_docs"; + // Input query text to compare against for text similarity model + public static final String QUERY_TEXT_FIELD = "query_text"; // Algorithm name protected FunctionName algorithm; @@ -157,6 +161,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0])); } } + break; + case TEXT_SIMILARITY: + TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset; + List tdocs = ds.getTextDocs(); + String queryText = ds.getQueryText(); + builder.field(QUERY_TEXT_FIELD, queryText); + if (tdocs != null && !tdocs.isEmpty()) { + builder.startArray(TEXT_DOCS_FIELD); + for(String d : tdocs) { + builder.value(d); + } + builder.endArray(); + } + break; default: break; } @@ -186,6 +204,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws List targetResponse = new ArrayList<>(); List targetResponsePositions = new ArrayList<>(); List textDocs = new ArrayList<>(); + String queryText = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -233,6 +252,9 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws textDocs.add(parser.text()); } break; + case QUERY_TEXT_FIELD: + queryText = parser.text(); + break; default: parser.skipChildren(); break; @@ -243,6 +265,9 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions); inputDataSet = new TextDocsInputDataSet(textDocs, filter); } + if (algorithm == FunctionName.TEXT_SIMILARITY) { + inputDataSet = new TextSimilarityInputDataSet(queryText, textDocs); + } return new MLInput(algorithm, mlParameters, searchSourceBuilder, sourceIndices, dataFrame, inputDataSet); } diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java new file mode 100644 index 0000000000..0c4d9f9a7b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java @@ -0,0 +1,116 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.common.input.nlp; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + + +/** + * MLInput which supports a text similarity algorithm + * Inputs are a query and a list of texts. Outputs are real numbers + * Use this for Cross Encoder models + */ +@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_SIMILARITY}) +public class TextSimilarityMLInput extends MLInput { + + public TextSimilarityMLInput(FunctionName algorithm, MLInputDataset dataset) { + super(algorithm, null, dataset); + } + + public TextSimilarityMLInput(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ALGORITHM_FIELD, algorithm.name()); + if(parameters != null) { + builder.field(ML_PARAMETERS_FIELD, parameters); + } + if(inputDataset != null) { + TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset; + List docs = ds.getTextDocs(); + String queryText = ds.getQueryText(); + builder.field(QUERY_TEXT_FIELD, queryText); + if (docs != null && !docs.isEmpty()) { + builder.startArray(TEXT_DOCS_FIELD); + for(String d : docs) { + builder.value(d); + } + builder.endArray(); + } + } + builder.endObject(); + return builder; + } + + public TextSimilarityMLInput(XContentParser parser, FunctionName functionName) throws IOException { + super(); + this.algorithm = functionName; + List docs = new ArrayList<>(); + String queryText = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TEXT_DOCS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + String context = parser.text(); + docs.add(context); + } + break; + case QUERY_TEXT_FIELD: + queryText = parser.text(); + default: + parser.skipChildren(); + break; + } + } + if(docs.isEmpty()) { + throw new IllegalArgumentException("No text documents were provided"); + } + if(queryText == null) { + throw new IllegalArgumentException("No query text was provided"); + } + inputDataset = new TextSimilarityInputDataSet(queryText, docs); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java new file mode 100644 index 0000000000..74969d537f --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.common.dataset; + +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.util.List; + +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class TextSimilarityInputDatasetTest { + + @Test + public void testStreaming() throws IOException { + List docs = List.of("That is a happy dog", "it's summer"); + String queryText = "today is sunny"; + TextSimilarityInputDataSet dataset = TextSimilarityInputDataSet.builder().queryText(queryText).textDocs(docs).build(); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + dataset.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + TextSimilarityInputDataSet newDs = (TextSimilarityInputDataSet) MLInputDataset.fromStream(in); + assert (dataset.getTextDocs().equals(newDs.getTextDocs())); + assert (dataset.getQueryText().equals(newDs.getQueryText())); + } + + @Test + public void noPairs_ThenFail() { + List docs = List.of(); + String queryText = "today is sunny"; + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build()); + assert (e.getMessage().equals("No text documents were provided")); + } + + @Test + public void noQuery_ThenFail() { + List docs = List.of("That is a happy dog", "it's summer"); + String queryText = null; + assertThrows(NullPointerException.class, + () -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java index 43cb4d6fd3..d08c678634 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java @@ -32,6 +32,8 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.input.nlp.TextSimilarityMLInput; import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams; import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.search.SearchModule; @@ -210,4 +212,20 @@ private void readInputStream(MLInput input, Consumer verify) throws IOE verify.accept(parsedInput); } + @Test + public void testParse_TextSimilarity() throws IOException { + List docs = List.of("That is a happy dog", "it's summer"); + String queryText = "today is sunny"; + MLInputDataset dataset = TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build(); + input = new TextSimilarityMLInput(FunctionName.TEXT_SIMILARITY, dataset); + MLInput inp = new MLInput(FunctionName.TEXT_SIMILARITY, null, dataset); + String expected = "{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[\"That is a happy dog\",\"it's summer\"]}"; + + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + inp.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals(expected, jsonStr); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java new file mode 100644 index 0000000000..296b939f5f --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java @@ -0,0 +1,134 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.common.input.nlp; + +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.search.SearchModule; + +public class TextSimilarityMLInputTest { + + MLInput input; + + private final FunctionName algorithm = FunctionName.TEXT_SIMILARITY; + + @Before + public void setup() { + List docs = List.of("That is a happy dog", "it's summer"); + String queryText = "today is sunny"; + MLInputDataset dataset = TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build(); + input = new TextSimilarityMLInput(algorithm, dataset); + } + + @Test + public void testXContent_IsInternallyConsistent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + XContentParser parser = XContentType.JSON.xContent() + .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + + MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); + assert (parsedInput instanceof TextSimilarityMLInput); + TextSimilarityMLInput parsedTSMLI = (TextSimilarityMLInput) parsedInput; + List docs = ((TextSimilarityInputDataSet) parsedTSMLI.getInputDataset()).getTextDocs(); + String queryText = ((TextSimilarityInputDataSet) parsedTSMLI.getInputDataset()).getQueryText(); + assert (docs.size() == 2); + assert (docs.get(0).equals("That is a happy dog")); + assert (docs.get(1).equals("it's summer")); + assert (queryText.equals("today is sunny")); + } + + @Test + public void testXContent_String() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assert (jsonStr.equals("{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[\"That is a happy dog\",\"it's summer\"]}")); + } + + @Test + public void testParseJson() throws IOException { + String json = "{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[\"That is a happy dog\",\"it's summer\"]}"; + XContentParser parser = XContentType.JSON.xContent() + .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, json); + parser.nextToken(); + + MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); + assert (parsedInput instanceof TextSimilarityMLInput); + TextSimilarityMLInput parsedTSMLI = (TextSimilarityMLInput) parsedInput; + List docs = ((TextSimilarityInputDataSet) parsedTSMLI.getInputDataset()).getTextDocs(); + String queryText = ((TextSimilarityInputDataSet) parsedTSMLI.getInputDataset()).getQueryText(); + assert (docs.size() == 2); + assert (docs.get(0).equals("That is a happy dog")); + assert (docs.get(1).equals("it's summer")); + assert (queryText.equals("today is sunny")); + } + + @Test + public void testParseJson_NoPairs_ThenFail() throws IOException { + String json = "{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[]}"; + XContentParser parser = XContentType.JSON.xContent() + .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, json); + parser.nextToken(); + + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> MLInput.parse(parser, input.getFunctionName().name())); + assert (e.getMessage().equals("No text documents were provided")); + } + + @Test + public void testStreaming() throws IOException { + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + input.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + TextSimilarityMLInput newInput = new TextSimilarityMLInput(in); + List newPairs = ((TextSimilarityInputDataSet) newInput.getInputDataset()).getTextDocs(); + List oldPairs = ((TextSimilarityInputDataSet) input.getInputDataset()).getTextDocs(); + assert (newPairs.equals(oldPairs)); + } + +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java new file mode 100644 index 0000000000..d3049c851a --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.engine.algorithms.text_similarity; + +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.engine.algorithms.DLModel; +import org.opensearch.ml.engine.annotation.Function; + +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; + +@Function(FunctionName.TEXT_SIMILARITY) +public class TextSimilarityCrossEncoderModel extends DLModel { + + @Override + public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException { + MLInputDataset inputDataSet = mlInput.getInputDataset(); + List tensorOutputs = new ArrayList<>(); + Output output; + TextSimilarityInputDataSet textSimInput = (TextSimilarityInputDataSet) inputDataSet; + String queryText = textSimInput.getQueryText(); + for (String doc : textSimInput.getTextDocs()) { + Input input = new Input(); + input.add(queryText); + input.add(doc); + output = getPredictor().predict(input); + ModelTensors outputTensors = ModelTensors.fromBytes(output.getData().getAsBytes()); + tensorOutputs.add(outputTensors); + } + return new ModelTensorOutput(tensorOutputs); + } + + @Override + public Translator getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException { + return new TextSimilarityTranslator(); + } + + @Override + public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) { + return null; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java new file mode 100644 index 0000000000..4967d2035b --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java @@ -0,0 +1,98 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.engine.algorithms.text_similarity; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.translate.TranslatorContext; + +public class TextSimilarityTranslator extends SentenceTransformerTranslator { + public final String SIMILARITY_NAME = "similarity"; + + @Override + public NDList processInput(TranslatorContext ctx, Input input) { + String sentence = input.getAsString(0); + String context = input.getAsString(1); + NDManager manager = ctx.getNDManager(); + NDList ndList = new NDList(); + Encoding encodings = tokenizer.encode(sentence, context); + long[] indices = encodings.getIds(); + long[] attentionMask = encodings.getAttentionMask(); + long[] tokenTypes = encodings.getTypeIds(); + + NDArray indicesArray = manager.create(indices); + indicesArray.setName("input_ids"); + + NDArray attentionMaskArray = manager.create(attentionMask); + attentionMaskArray.setName("attention_mask"); + + NDArray tokenTypeArray = manager.create(tokenTypes); + tokenTypeArray.setName("token_type_ids"); + + ndList.add(indicesArray); + ndList.add(attentionMaskArray); + ndList.add(tokenTypeArray); + return ndList; + } + + @Override + public Output processOutput(TranslatorContext ctx, NDList list) { + Output output = new Output(200, "OK"); + + List outputs = new ArrayList<>(); + Iterator iterator = list.iterator(); + while (iterator.hasNext()) { + NDArray ndArray = iterator.next(); + String name = SIMILARITY_NAME; + Number[] data = ndArray.toArray(); + long[] shape = ndArray.getShape().getShape(); + DataType dataType = ndArray.getDataType(); + MLResultDataType mlResultDataType = MLResultDataType.valueOf(dataType.name()); + ByteBuffer buffer = ndArray.toByteBuffer(); + ModelTensor tensor = ModelTensor + .builder() + .name(name) + .data(data) + .shape(shape) + .dataType(mlResultDataType) + .byteBuffer(buffer) + .build(); + outputs.add(tensor); + } + + ModelTensors modelTensorOutput = new ModelTensors(outputs); + output.add(modelTensorOutput.toBytes()); + return output; + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java new file mode 100644 index 0000000000..88e64e2517 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java @@ -0,0 +1,319 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.engine.algorithms.text_similarity; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.algorithms.DLModel.*; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.engine.utils.FileUtils; + +import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.TranslatorContext; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class TextSimilarityCrossEncoderModelTest { + + private File modelZipFile; + private MLModel model; + private ModelHelper modelHelper; + private Map params; + private TextSimilarityCrossEncoderModel textSimilarityCrossEncoderModel; + private Path mlCachePath; + private TextSimilarityInputDataSet inputDataSet; + private MLEngine mlEngine; + private Encryptor encryptor; + + @Before + public void setUp() throws URISyntaxException { + mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + mlEngine = new MLEngine(mlCachePath, encryptor); + model = MLModel + .builder() + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .name("test_model_name") + .modelId("test_model_id") + .algorithm(FunctionName.TEXT_SIMILARITY) + .version("1.0.0") + .modelState(MLModelState.TRAINED) + .build(); + modelHelper = new ModelHelper(mlEngine); + params = new HashMap<>(); + modelZipFile = new File(getClass().getResource("TinyBERT-CE-torch_script.zip").toURI()); + params.put(MODEL_ZIP_FILE, modelZipFile); + params.put(MODEL_HELPER, modelHelper); + params.put(ML_ENGINE, mlEngine); + textSimilarityCrossEncoderModel = new TextSimilarityCrossEncoderModel(); + + inputDataSet = TextSimilarityInputDataSet + .builder() + .textDocs(Arrays.asList("That is a happy dog", "it's summer")) + .queryText("it's summer") + .build(); + } + + @Test + public void test_TextSimilarity_Translator_ProcessInput() throws URISyntaxException, IOException { + TextSimilarityTranslator textSimilarityTranslator = new TextSimilarityTranslator(); + TranslatorContext translatorContext = mock(TranslatorContext.class); + Model mlModel = mock(Model.class); + when(translatorContext.getModel()).thenReturn(mlModel); + when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("../tokenize/tokenizer.json").toURI()).getParent()); + textSimilarityTranslator.prepare(translatorContext); + + NDManager manager = mock(NDManager.class); + when(translatorContext.getNDManager()).thenReturn(manager); + Input input = mock(Input.class); + String testSentence = "hello world"; + when(input.getAsString(0)).thenReturn(testSentence); + when(input.getAsString(1)).thenReturn(testSentence); + NDArray indiceNdArray = mock(NDArray.class); + when(indiceNdArray.toLongArray()).thenReturn(new long[] { 102l, 101l }); + when(manager.create((long[]) any())).thenReturn(indiceNdArray); + doNothing().when(indiceNdArray).setName(any()); + NDList outputList = textSimilarityTranslator.processInput(translatorContext, input); + assertEquals(3, outputList.size()); + Iterator iterator = outputList.iterator(); + while (iterator.hasNext()) { + NDArray ndArray = iterator.next(); + long[] output = ndArray.toLongArray(); + assertEquals(2, output.length); + } + } + + @Test + public void test_TextSimilarity_Translator_ProcessOutput() throws URISyntaxException, IOException { + TextSimilarityTranslator textSimilarityTranslator = new TextSimilarityTranslator(); + TranslatorContext translatorContext = mock(TranslatorContext.class); + Model mlModel = mock(Model.class); + when(translatorContext.getModel()).thenReturn(mlModel); + when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("../tokenize/tokenizer.json").toURI()).getParent()); + textSimilarityTranslator.prepare(translatorContext); + + NDArray ndArray = mock(NDArray.class); + Shape shape = mock(Shape.class); + when(ndArray.nonzero()).thenReturn(ndArray); + when(ndArray.squeeze()).thenReturn(ndArray); + when(ndArray.getFloat(any())).thenReturn(1.0f); + when(ndArray.toArray()).thenReturn(new Number[] { 1.245f }); + when(ndArray.getName()).thenReturn("output"); + when(ndArray.getShape()).thenReturn(shape); + when(shape.getShape()).thenReturn(new long[] { 1 }); + when(ndArray.getDataType()).thenReturn(DataType.FLOAT32); + List ndArrayList = Collections.singletonList(ndArray); + NDList ndList = new NDList(ndArrayList); + Output output = textSimilarityTranslator.processOutput(translatorContext, ndList); + assertNotNull(output); + byte[] bytes = output.getData().getAsBytes(); + ModelTensors tensorOutput = ModelTensors.fromBytes(bytes); + List modelTensorsList = tensorOutput.getMlModelTensors(); + assertEquals(1, modelTensorsList.size()); + ModelTensor modelTensor = modelTensorsList.get(0); + assertEquals("similarity", modelTensor.getName()); + Number[] data = modelTensor.getData(); + assertEquals(1, data.length); + } + + @Test + public void initModel_predict_TorchScript_CrossEncoder() throws URISyntaxException { + textSimilarityCrossEncoderModel.initModel(model, params, encryptor); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build(); + ModelTensorOutput output = (ModelTensorOutput) textSimilarityCrossEncoderModel.predict(mlInput); + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + for (int i = 0; i < mlModelOutputs.size(); i++) { + ModelTensors tensors = mlModelOutputs.get(i); + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + assertEquals(1, mlModelTensors.get(0).getData().length); + } + textSimilarityCrossEncoderModel.close(); + } + + @Test + public void initModel_predict_ONNX_CrossEncoder() throws URISyntaxException { + model = MLModel + .builder() + .modelFormat(MLModelFormat.ONNX) + .name("test_model_name") + .modelId("test_model_id") + .algorithm(FunctionName.TEXT_SIMILARITY) + .version("1.0.0") + .modelState(MLModelState.TRAINED) + .build(); + modelZipFile = new File(getClass().getResource("TinyBERT-CE-onnx.zip").toURI()); + params.put(MODEL_ZIP_FILE, modelZipFile); + + textSimilarityCrossEncoderModel.initModel(model, params, encryptor); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build(); + ModelTensorOutput output = (ModelTensorOutput) textSimilarityCrossEncoderModel.predict(mlInput); + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + for (int i = 0; i < mlModelOutputs.size(); i++) { + ModelTensors tensors = mlModelOutputs.get(i); + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + assertEquals(1, mlModelTensors.get(0).getData().length); + } + textSimilarityCrossEncoderModel.close(); + } + + @Test + public void initModel_NullModelHelper() throws URISyntaxException { + Map params = new HashMap<>(); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource("TinyBERT-CE-torch_script.zip").toURI())); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor) + ); + assert (e.getMessage().equals("model helper is null")); + } + + @Test + public void initModel_NullMLEngine() throws URISyntaxException { + Map params = new HashMap<>(); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource("TinyBERT-CE-torch_script.zip").toURI())); + params.put(MODEL_HELPER, modelHelper); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor) + ); + assert (e.getMessage().equals("ML engine is null")); + } + + @Test + public void initModel_NullModelId() { + model.setModelId(null); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor) + ); + assert (e.getMessage().equals("model id is null")); + } + + @Test + public void initModel_WrongModelFile() throws URISyntaxException { + Map params = new HashMap<>(); + params.put(MODEL_HELPER, modelHelper); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource("../text_embedding/wrong_zip_with_2_pt_file.zip").toURI())); + params.put(ML_ENGINE, mlEngine); + MLException e = assertThrows(MLException.class, () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)); + Throwable rootCause = e.getCause(); + assert (rootCause instanceof IllegalArgumentException); + assert (rootCause.getMessage().equals("found multiple models")); + } + + @Test + public void initModel_WrongFunctionName() { + MLModel mlModel = model.toBuilder().algorithm(FunctionName.KMEANS).build(); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel.initModel(mlModel, params, encryptor) + ); + assert (e.getMessage().equals("wrong function name")); + } + + @Test + public void predict_NullModelHelper() { + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel + .predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build()) + ); + assert (e.getMessage().equals("model not deployed")); + } + + @Test + public void predict_NullModelId() { + model.setModelId(null); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor) + ); + assert (e.getMessage().equals("model id is null")); + IllegalArgumentException e2 = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel + .predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build()) + ); + assert (e2.getMessage().equals("model not deployed")); + } + + @Test + public void predict_AfterModelClosed() { + textSimilarityCrossEncoderModel.initModel(model, params, encryptor); + textSimilarityCrossEncoderModel.close(); + MLException e = assertThrows( + MLException.class, + () -> textSimilarityCrossEncoderModel + .predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build()) + ); + log.info(e.getMessage()); + assert (e.getMessage().startsWith("Failed to inference TEXT_SIMILARITY")); + } + + @After + public void tearDown() { + FileUtils.deleteFileQuietly(mlCachePath); + } + +} diff --git a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-onnx.zip b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-onnx.zip new file mode 100644 index 0000000000..fd8b7841e1 Binary files /dev/null and b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-onnx.zip differ diff --git a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-torch_script.zip b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-torch_script.zip new file mode 100644 index 0000000000..c4af9a218c Binary files /dev/null and b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-torch_script.zip differ