Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add cross encoder support #1739

Merged
merged 1 commit into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions common/src/main/java/org/opensearch/ml/common/FunctionName.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.common;

import java.util.HashSet;
import java.util.Set;

public enum FunctionName {
LINEAR_REGRESSION,
KMEANS,
Expand All @@ -17,6 +20,7 @@ public enum FunctionName {
RCF_SUMMARIZE,
LOGISTIC_REGRESSION,
TEXT_EMBEDDING,
TEXT_SIMILARITY,
SPARSE_ENCODING,
SPARSE_TOKENIZE,
METRICS_CORRELATION,
Expand All @@ -30,14 +34,18 @@ public static FunctionName from(String value) {
}
}

// Function names that isDLModel(...) reports as deep learning models.
// NOTE(review): the Set.of(...) literal is copied into a HashSet; presumably this keeps
// contains(null) returning false (the immutable Set.of sets throw on null lookups) --
// confirm before simplifying to a bare Set.of(...).
private static final HashSet<FunctionName> DL_MODELS = new HashSet<>(Set.of(
TEXT_EMBEDDING,
TEXT_SIMILARITY,
SPARSE_ENCODING,
SPARSE_TOKENIZE
));

/**
 * Checks whether the given function name denotes a deep learning model.
 *
 * <p>Membership is decided solely by the {@code DL_MODELS} set, so adding a new
 * deep learning function only requires updating that set.
 *
 * @param functionName the function name to test
 * @return {@code true} if {@code functionName} is contained in {@code DL_MODELS},
 *         {@code false} otherwise
 */
public static boolean isDLModel(FunctionName functionName) {
    return DL_MODELS.contains(functionName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ public enum MLInputDataType {
SEARCH_QUERY,
DATA_FRAME,
TEXT_DOCS,
TEXT_SIMILARITY,
REMOTE
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.common.dataset;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.annotation.InputDataSet;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.FieldDefaults;

/**
 * Input dataset of type {@code MLInputDataType.TEXT_SIMILARITY}: a single query text
 * together with a non-empty list of text documents.
 */
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@InputDataSet(MLInputDataType.TEXT_SIMILARITY)
public class TextSimilarityInputDataSet extends MLInputDataset {

    /** Text documents carried by this dataset. */
    List<String> textDocs;

    /** Query text carried by this dataset. */
    String queryText;

    /**
     * Builds a dataset from a query and its documents.
     *
     * @param queryText the query text; must not be null
     * @param textDocs the text documents; must be neither null nor empty
     * @throws NullPointerException if either argument is null
     * @throws IllegalArgumentException if {@code textDocs} is empty
     */
    @Builder(toBuilder = true)
    public TextSimilarityInputDataSet(String queryText, List<String> textDocs) {
        super(MLInputDataType.TEXT_SIMILARITY);
        Objects.requireNonNull(textDocs);
        Objects.requireNonNull(queryText);
        final boolean nothingToCompare = textDocs.isEmpty();
        if (nothingToCompare) {
            throw new IllegalArgumentException("No text documents were provided");
        }
        this.queryText = queryText;
        this.textDocs = textDocs;
    }

    /**
     * Reads a dataset from a stream: the query string, then an int document count,
     * then that many document strings (the mirror image of what {@link #writeTo} emits
     * after the superclass portion).
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public TextSimilarityInputDataSet(StreamInput in) throws IOException {
        super(MLInputDataType.TEXT_SIMILARITY);
        this.queryText = in.readString();
        int remaining = in.readInt();
        final List<String> received = new ArrayList<String>();
        while (remaining > 0) {
            received.add(in.readString());
            remaining--;
        }
        this.textDocs = received;
    }

    /**
     * Writes the superclass portion, then the query string, the document count as an int,
     * and each document string in list order.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeString(queryText);
        final List<String> outgoing = this.textDocs;
        out.writeInt(outgoing.size());
        for (String document : outgoing) {
            out.writeString(document);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
Expand All @@ -21,6 +22,7 @@
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.search.builder.SearchSourceBuilder;

Expand Down Expand Up @@ -55,6 +57,8 @@ public class MLInput implements Input {
public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";
// Input text sentences for text embedding model
public static final String TEXT_DOCS_FIELD = "text_docs";
// Input query text to compare against for text similarity model
public static final String QUERY_TEXT_FIELD = "query_text";

// Algorithm name
protected FunctionName algorithm;
Expand Down Expand Up @@ -157,6 +161,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0]));
}
}
break;
case TEXT_SIMILARITY:
TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset;
List<String> tdocs = ds.getTextDocs();
String queryText = ds.getQueryText();
builder.field(QUERY_TEXT_FIELD, queryText);
if (tdocs != null && !tdocs.isEmpty()) {
builder.startArray(TEXT_DOCS_FIELD);
for(String d : tdocs) {
builder.value(d);
}
builder.endArray();
}
break;
default:
break;
}
Expand Down Expand Up @@ -186,6 +204,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
List<String> targetResponse = new ArrayList<>();
List<Integer> targetResponsePositions = new ArrayList<>();
List<String> textDocs = new ArrayList<>();
String queryText = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -233,6 +252,9 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
textDocs.add(parser.text());
}
break;
case QUERY_TEXT_FIELD:
queryText = parser.text();
break;
default:
parser.skipChildren();
break;
Expand All @@ -243,6 +265,9 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
inputDataSet = new TextDocsInputDataSet(textDocs, filter);
}
if (algorithm == FunctionName.TEXT_SIMILARITY) {
inputDataSet = new TextSimilarityInputDataSet(queryText, textDocs);
}
return new MLInput(algorithm, mlParameters, searchSourceBuilder, sourceIndices, dataFrame, inputDataSet);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.common.input.nlp;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;


/**
 * MLInput which supports a text similarity algorithm.
 * Inputs are a query and a list of texts. Outputs are real numbers.
 * Use this for Cross Encoder models.
 */
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_SIMILARITY})
public class TextSimilarityMLInput extends MLInput {

    /**
     * Creates an input for the given algorithm and dataset.
     * Forwards to the superclass constructor with {@code null} as its second argument
     * (presumably the algorithm parameters -- confirm against {@code MLInput}).
     *
     * @param algorithm the function name this input targets
     * @param dataset the input dataset
     */
    public TextSimilarityMLInput(FunctionName algorithm, MLInputDataset dataset) {
        super(algorithm, null, dataset);
    }

    /**
     * Reads an input from a stream by delegating entirely to the superclass.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public TextSimilarityMLInput(StreamInput in) throws IOException {
        super(in);
    }

    /**
     * Writes this input to a stream by delegating entirely to the superclass.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
    }

    /**
     * Serializes this input as an object containing the algorithm name, the parameters
     * (when non-null), and -- when a dataset is present -- the query text followed by
     * the text documents array (the array is omitted when the list is null or empty).
     *
     * @param builder the builder to write into
     * @param params serialization parameters (not consulted here)
     * @return the same builder, for chaining
     * @throws IOException if writing to the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(ALGORITHM_FIELD, algorithm.name());
        if (parameters != null) {
            builder.field(ML_PARAMETERS_FIELD, parameters);
        }
        if (inputDataset != null) {
            TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset;
            List<String> docs = ds.getTextDocs();
            String queryText = ds.getQueryText();
            builder.field(QUERY_TEXT_FIELD, queryText);
            if (docs != null && !docs.isEmpty()) {
                builder.startArray(TEXT_DOCS_FIELD);
                for (String d : docs) {
                    builder.value(d);
                }
                builder.endArray();
            }
        }
        builder.endObject();
        return builder;
    }

    /**
     * Parses an input from XContent. The parser must be positioned on the START_OBJECT
     * token. Recognized fields are the text documents array and the query text; any
     * other field is skipped.
     *
     * @param parser the parser to read from
     * @param functionName the function name to record as this input's algorithm
     * @throws IOException if the parser fails while reading
     * @throws IllegalArgumentException if no text documents or no query text were provided
     */
    public TextSimilarityMLInput(XContentParser parser, FunctionName functionName) throws IOException {
        super();
        this.algorithm = functionName;
        List<String> docs = new ArrayList<>();
        String queryText = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();

            switch (fieldName) {
                case TEXT_DOCS_FIELD:
                    ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                        String context = parser.text();
                        docs.add(context);
                    }
                    break;
                case QUERY_TEXT_FIELD:
                    queryText = parser.text();
                    // Explicit break: a recognized field must not fall through into the
                    // unknown-field handling below.
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        // Validate here so callers get a descriptive message for each missing field.
        if (docs.isEmpty()) {
            throw new IllegalArgumentException("No text documents were provided");
        }
        if (queryText == null) {
            throw new IllegalArgumentException("No query text was provided");
        }
        inputDataset = new TextSimilarityInputDataSet(queryText, docs);
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.common.dataset;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;

import java.io.IOException;
import java.util.List;

import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

/**
 * Unit tests for {@link TextSimilarityInputDataSet}: stream round-trip and
 * constructor argument validation.
 */
public class TextSimilarityInputDatasetTest {

    /** A dataset written to a stream and read back must carry the same docs and query. */
    @Test
    public void testStreaming() throws IOException {
        List<String> docs = List.of("That is a happy dog", "it's summer");
        String queryText = "today is sunny";
        TextSimilarityInputDataSet dataset = TextSimilarityInputDataSet.builder().queryText(queryText).textDocs(docs).build();
        BytesStreamOutput outbytes = new BytesStreamOutput();
        StreamOutput osso = new OutputStreamStreamOutput(outbytes);
        dataset.writeTo(osso);
        StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes()));
        TextSimilarityInputDataSet newDs = (TextSimilarityInputDataSet) MLInputDataset.fromStream(in);
        // JUnit assertions run unconditionally; the `assert` keyword is skipped without -ea.
        assertEquals(dataset.getTextDocs(), newDs.getTextDocs());
        assertEquals(dataset.getQueryText(), newDs.getQueryText());
    }

    /** An empty document list must be rejected with a descriptive message. */
    @Test
    public void noPairs_ThenFail() {
        List<String> docs = List.of();
        String queryText = "today is sunny";
        IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
            () -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build());
        assertEquals("No text documents were provided", e.getMessage());
    }

    /** A null query text must be rejected with a NullPointerException. */
    @Test
    public void noQuery_ThenFail() {
        List<String> docs = List.of("That is a happy dog", "it's summer");
        String queryText = null;
        assertThrows(NullPointerException.class,
            () -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build());
    }
}
Loading
Loading