Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

asymmetric embeddings #2123

Merged
merged 7 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.dataset;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.Locale;

import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;

import lombok.Builder;
import lombok.Data;

/**
* This class defines the modes of operation of an asymmetric text embedding model.
* Asymmetric embedding models treat the input text differently, depending on whether it is a
* passage or a query. One example asymmetric model, that requires different prefixes is e5
* (cf. https://arxiv.org/pdf/2212.03533.pdf).
* <p>
* Use this parameter only if the model is asymmetric and has been registered with the corresponding
* `query_prefix` and `passage_prefix` configuration parameters.
*/
@Data
@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING })
public class AsymmetricTextEmbeddingParameters implements MLAlgoParams {
br3no marked this conversation as resolved.
Show resolved Hide resolved

public enum EmbeddingContentType {
QUERY,
PASSAGE
}

public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name();
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
MLAlgoParams.class,
new ParseField(PARSE_FIELD_NAME),
it -> parse(it)

Check warning on line 48 in common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java#L48

Added line #L48 was not covered by tests
);

@Builder(toBuilder = true)
public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType) {
this.embeddingContentType = embeddingContentType;
}

public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException {
this.embeddingContentType = EmbeddingContentType.valueOf(in.readOptionalString());
}

public static MLAlgoParams parse(XContentParser parser) throws IOException {
EmbeddingContentType embeddingContentType = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case EMBEDDING_CONTENT_TYPE_FIELD:
String contentType = parser.text();
embeddingContentType = EmbeddingContentType.valueOf(contentType.toUpperCase(Locale.ROOT));
break;
default:
parser.skipChildren();

Check warning on line 74 in common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java#L74

Added line #L74 was not covered by tests
break;
}
}
return new AsymmetricTextEmbeddingParameters(embeddingContentType);
}

public static final String EMBEDDING_CONTENT_TYPE_FIELD = "content_type";

// The type of the content to be embedded
private EmbeddingContentType embeddingContentType;

@Override
public int getVersion() {
return 1;

Check warning on line 88 in common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java#L88

Added line #L88 was not covered by tests
}

@Override
public String getWriteableName() {
return PARSE_FIELD_NAME;

Check warning on line 93 in common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java#L93

Added line #L93 was not covered by tests
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(embeddingContentType.name());
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
xContentBuilder.startObject();
if (embeddingContentType != null) {
xContentBuilder.field(EMBEDDING_CONTENT_TYPE_FIELD, embeddingContentType.name());
}
xContentBuilder.endObject();
return xContentBuilder;
}

public EmbeddingContentType getEmbeddingContentType() {
return embeddingContentType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,25 @@ public class TextEmbeddingModelConfig extends MLModelConfig {
public static final String POOLING_MODE_FIELD = "pooling_mode";
public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";
public static final String QUERY_PREFIX = "query_prefix";
public static final String PASSAGE_PREFIX = "passage_prefix";

private final Integer embeddingDimension;
private final FrameworkType frameworkType;
private final PoolingMode poolingMode;
private final boolean normalizeResult;
private final Integer modelMaxLength;
private final String queryPrefix;
private final String passagePrefix;

public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
this(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, null, null);
}

@Builder(toBuilder = true)
public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength, String queryPrefix, String passagePrefix) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not add a check function here to ensure at least one of the new parameters is not Null? It looks like this check can avoid problems later in loading the asymmetric model.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TextEmbeddingModelConfig is used for both asymmetric and symmetric models. Symmetric models don't have prefixes. The alternative would be to create a new TextEmbeddingModelConfig class for asymmetric models. I don't think this would be a cleaner solution, tbh.

super(modelType, allConfig);
if (embeddingDimension == null) {
throw new IllegalArgumentException("embedding dimension is null");
Expand All @@ -59,6 +68,8 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr
this.poolingMode = poolingMode;
this.normalizeResult = normalizeResult;
this.modelMaxLength = modelMaxLength;
this.queryPrefix = queryPrefix;
this.passagePrefix = passagePrefix;
}

public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException {
Expand All @@ -69,6 +80,8 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
PoolingMode poolingMode = null;
boolean normalizeResult = false;
Integer modelMaxLength = null;
String queryPrefix = null;
String passagePrefix = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -97,12 +110,18 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
case MODEL_MAX_LENGTH_FIELD:
modelMaxLength = parser.intValue();
break;
case QUERY_PREFIX:
queryPrefix = parser.text();
break;
br3no marked this conversation as resolved.
Show resolved Hide resolved
case PASSAGE_PREFIX:
passagePrefix = parser.text();
break;
br3no marked this conversation as resolved.
Show resolved Hide resolved
default:
parser.skipChildren();
break;
}
}
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength);
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, queryPrefix, passagePrefix);
}

@Override
Expand All @@ -121,6 +140,8 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{
}
normalizeResult = in.readBoolean();
modelMaxLength = in.readOptionalInt();
queryPrefix = in.readOptionalString();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can focus more on the feature itself in this PR. But later when we release this in new version. We should add version check. @b4sjoo Can you help when release new version.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both queryPrefix and passagePrefix are optional, so this shouldn't break compatibility. Are you sure there is need for a version check?

passagePrefix = in.readOptionalString();
}

@Override
Expand All @@ -136,6 +157,8 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeBoolean(normalizeResult);
out.writeOptionalInt(modelMaxLength);
out.writeOptionalString(queryPrefix);
out.writeOptionalString(passagePrefix);
}

@Override
Expand All @@ -162,6 +185,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (normalizeResult) {
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
}
if (queryPrefix != null) {
builder.field(QUERY_PREFIX, queryPrefix);
}
if (passagePrefix != null) {
builder.field(PASSAGE_PREFIX, passagePrefix);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.opensearch.ml.common.dataset;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;
import java.util.function.Function;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.common.TestHelper.contentObjectToString;
import static org.opensearch.ml.common.TestHelper.testParseFromString;

public class AsymmetricTextEmbeddingParametersTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

AsymmetricTextEmbeddingParameters params;
private Function<XContentParser, AsymmetricTextEmbeddingParameters> function = parser -> {
try {
return (AsymmetricTextEmbeddingParameters) AsymmetricTextEmbeddingParameters.parse(parser);
} catch (IOException e) {
throw new RuntimeException("failed to parse AsymmetricTextEmbeddingParameters", e);
}
};

@Before
public void setUp() {
params = AsymmetricTextEmbeddingParameters.builder()
.embeddingContentType(EmbeddingContentType.QUERY)
.build();
}

@Test
public void parse_AsymmetricTextEmbeddingParameters() throws IOException {
TestHelper.testParse(params, function);
}

@Test
public void parse_AsymmetricTextEmbeddingParameters_Passage() throws IOException {
String paramsStr = contentObjectToString(params);
testParseFromString(params, paramsStr.replace("QUERY", "PASSAGE"), function);
}

@Test
public void parse_AsymmetricTextEmbeddingParameters_Invalid() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU");
String paramsStr = contentObjectToString(params);
testParseFromString(params, paramsStr.replace("QUERY","fu"), function);
}

@Test
public void parse_EmptyAsymmetricTextEmbeddingParameters() throws IOException {
TestHelper.testParse(AsymmetricTextEmbeddingParameters.builder().build(), function);
}

@Test
public void readInputStream_Success() throws IOException {
readInputStream(params);
}

@Test
public void readInputStream_Success_EmptyParams() throws IOException {
readInputStream(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build());
}

private void readInputStream(AsymmetricTextEmbeddingParameters params) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
params.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
AsymmetricTextEmbeddingParameters parsedParams = new AsymmetricTextEmbeddingParameters(streamInput);
assertEquals(params, parsedParams);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ public class TextDocsMLInputTest {
@Before
public void setUp() throws Exception {
ModelResultFilter resultFilter = ModelResultFilter.builder().returnBytes(true).returnNumber(true)
.targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build();
.targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build();
MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2"))
.resultFilter(resultFilter).build();
.resultFilter(resultFilter).build();
input = new TextDocsMLInput(algorithm, inputDataset);
}

Expand All @@ -68,8 +68,8 @@ public void parseTextDocsMLInput_NewWay() throws IOException {

private void parseMLInput(String jsonStr, int docSize) throws IOException {
XContentParser parser = XContentType.JSON.xContent()
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();

MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public void setUp() {
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
.embeddingDimension(100)
.passagePrefix("passage: ")
.queryPrefix("query: ")
.build();
function = parser -> {
try {
Expand All @@ -51,7 +53,7 @@ public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
config.toXContent(builder, EMPTY_PARAMS);
String configContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent);
assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}", configContent);
}

@Test
Expand Down Expand Up @@ -83,7 +85,7 @@ public void nullFields_FrameworkType() {

@Test
public void parse() throws IOException {
String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}";
String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}";
TestHelper.testParseFromString(config, content, function);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@
case TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD:
configBuilder.modelMaxLength(((Double) configEntry.getValue()).intValue());
break;
case TextEmbeddingModelConfig.QUERY_PREFIX:
configBuilder.queryPrefix(configEntry.getValue().toString());
break;

Check warning on line 140 in ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java#L139-L140

Added lines #L139 - L140 were not covered by tests
br3no marked this conversation as resolved.
Show resolved Hide resolved
case TextEmbeddingModelConfig.PASSAGE_PREFIX:
configBuilder.passagePrefix(configEntry.getValue().toString());
break;

Check warning on line 143 in ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java#L142-L143

Added lines #L142 - L143 were not covered by tests
default:
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ public abstract class DLModel implements Predictable {
protected Device[] devices;
protected AtomicInteger nextDevice = new AtomicInteger(0);

protected MLModelConfig modelConfig;

@Override
public MLOutput predict(MLInput mlInput, MLModel model) {
throw new IllegalArgumentException("model not deployed");
Expand Down Expand Up @@ -183,6 +185,7 @@ protected void doLoadModel(
IOException,
TranslateException {
devices = Engine.getEngine(engine).getDevices();
this.modelConfig = modelConfig;
for (int i = 0; i < devices.length; i++) {
log.debug("load model {} to device {}: {}", modelId, i, devices[i]);
ZooModel<Input, Output> model;
Expand Down
Loading
Loading