Skip to content

Commit

Permalink
1799 asymmetric embeddings (opensearch-project#2123)
Browse files Browse the repository at this point in the history
* add support for asymmetric embeddings

Signed-off-by: br3no <[email protected]>

* fix NPE with sparse models

Signed-off-by: br3no <[email protected]>

* after review

Signed-off-by: br3no <[email protected]>

* improving code coverage

Signed-off-by: br3no <[email protected]>

* improving javadocs

Signed-off-by: br3no <[email protected]>

* adding unit-test for AsymmetricTextEmbeddingParameters

Signed-off-by: br3no <[email protected]>

* after latest review

Signed-off-by: br3no <[email protected]>

---------

Signed-off-by: br3no <[email protected]>
  • Loading branch information
br3no authored and austintlee committed Mar 18, 2024
1 parent c4286ff commit 1cbaed0
Show file tree
Hide file tree
Showing 11 changed files with 615 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.dataset;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.Locale;

import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;

import lombok.Builder;
import lombok.Data;

/**
* This class defines the modes of operation of an asymmetric text embedding model.
* Asymmetric embedding models treat the input text differently, depending on whether it is a
* passage or a query. One example asymmetric model, that requires different prefixes is e5
* (cf. https://arxiv.org/pdf/2212.03533.pdf).
* <p>
* Use this parameter only if the model is asymmetric and has been registered with the corresponding
* `query_prefix` and `passage_prefix` configuration parameters.
*/
@Data
@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING })
public class AsymmetricTextEmbeddingParameters implements MLAlgoParams {

public enum EmbeddingContentType {
QUERY,
PASSAGE
}

public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name();
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
MLAlgoParams.class,
new ParseField(PARSE_FIELD_NAME),
it -> parse(it)
);

@Builder(toBuilder = true)
public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType) {
this.embeddingContentType = embeddingContentType;
}

public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException {
this.embeddingContentType = EmbeddingContentType.valueOf(in.readOptionalString());
}

public static MLAlgoParams parse(XContentParser parser) throws IOException {
EmbeddingContentType embeddingContentType = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case EMBEDDING_CONTENT_TYPE_FIELD:
String contentType = parser.text();
embeddingContentType = EmbeddingContentType.valueOf(contentType.toUpperCase(Locale.ROOT));
break;
default:
parser.skipChildren();
break;
}
}
return new AsymmetricTextEmbeddingParameters(embeddingContentType);
}

public static final String EMBEDDING_CONTENT_TYPE_FIELD = "content_type";

// The type of the content to be embedded
private EmbeddingContentType embeddingContentType;

@Override
public int getVersion() {
return 1;
}

@Override
public String getWriteableName() {
return PARSE_FIELD_NAME;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(embeddingContentType.name());
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
xContentBuilder.startObject();
if (embeddingContentType != null) {
xContentBuilder.field(EMBEDDING_CONTENT_TYPE_FIELD, embeddingContentType.name());
}
xContentBuilder.endObject();
return xContentBuilder;
}

public EmbeddingContentType getEmbeddingContentType() {
return embeddingContentType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,25 @@ public class TextEmbeddingModelConfig extends MLModelConfig {
public static final String POOLING_MODE_FIELD = "pooling_mode";
public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";
public static final String QUERY_PREFIX = "query_prefix";
public static final String PASSAGE_PREFIX = "passage_prefix";

private final Integer embeddingDimension;
private final FrameworkType frameworkType;
private final PoolingMode poolingMode;
private final boolean normalizeResult;
private final Integer modelMaxLength;
private final String queryPrefix;
private final String passagePrefix;

public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
this(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, null, null);
}

@Builder(toBuilder = true)
public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength, String queryPrefix, String passagePrefix) {
super(modelType, allConfig);
if (embeddingDimension == null) {
throw new IllegalArgumentException("embedding dimension is null");
Expand All @@ -59,6 +68,8 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr
this.poolingMode = poolingMode;
this.normalizeResult = normalizeResult;
this.modelMaxLength = modelMaxLength;
this.queryPrefix = queryPrefix;
this.passagePrefix = passagePrefix;
}

public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException {
Expand All @@ -69,6 +80,8 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
PoolingMode poolingMode = null;
boolean normalizeResult = false;
Integer modelMaxLength = null;
String queryPrefix = null;
String passagePrefix = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -97,12 +110,18 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
case MODEL_MAX_LENGTH_FIELD:
modelMaxLength = parser.intValue();
break;
case QUERY_PREFIX:
queryPrefix = parser.text();
break;
case PASSAGE_PREFIX:
passagePrefix = parser.text();
break;
default:
parser.skipChildren();
break;
}
}
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength);
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, queryPrefix, passagePrefix);
}

@Override
Expand All @@ -121,6 +140,8 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{
}
normalizeResult = in.readBoolean();
modelMaxLength = in.readOptionalInt();
queryPrefix = in.readOptionalString();
passagePrefix = in.readOptionalString();
}

@Override
Expand All @@ -136,6 +157,8 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeBoolean(normalizeResult);
out.writeOptionalInt(modelMaxLength);
out.writeOptionalString(queryPrefix);
out.writeOptionalString(passagePrefix);
}

@Override
Expand All @@ -162,6 +185,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (normalizeResult) {
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
}
if (queryPrefix != null) {
builder.field(QUERY_PREFIX, queryPrefix);
}
if (passagePrefix != null) {
builder.field(PASSAGE_PREFIX, passagePrefix);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.opensearch.ml.common.dataset;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;
import java.util.function.Function;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.common.TestHelper.contentObjectToString;
import static org.opensearch.ml.common.TestHelper.testParseFromString;

public class AsymmetricTextEmbeddingParametersTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

AsymmetricTextEmbeddingParameters params;
private Function<XContentParser, AsymmetricTextEmbeddingParameters> function = parser -> {
try {
return (AsymmetricTextEmbeddingParameters) AsymmetricTextEmbeddingParameters.parse(parser);
} catch (IOException e) {
throw new RuntimeException("failed to parse AsymmetricTextEmbeddingParameters", e);
}
};

@Before
public void setUp() {
params = AsymmetricTextEmbeddingParameters.builder()
.embeddingContentType(EmbeddingContentType.QUERY)
.build();
}

@Test
public void parse_AsymmetricTextEmbeddingParameters() throws IOException {
TestHelper.testParse(params, function);
}

@Test
public void parse_AsymmetricTextEmbeddingParameters_Passage() throws IOException {
String paramsStr = contentObjectToString(params);
testParseFromString(params, paramsStr.replace("QUERY", "PASSAGE"), function);
}

@Test
public void parse_AsymmetricTextEmbeddingParameters_Invalid() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU");
String paramsStr = contentObjectToString(params);
testParseFromString(params, paramsStr.replace("QUERY","fu"), function);
}

@Test
public void parse_EmptyAsymmetricTextEmbeddingParameters() throws IOException {
TestHelper.testParse(AsymmetricTextEmbeddingParameters.builder().build(), function);
}

@Test
public void readInputStream_Success() throws IOException {
readInputStream(params);
}

@Test
public void readInputStream_Success_EmptyParams() throws IOException {
readInputStream(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build());
}

private void readInputStream(AsymmetricTextEmbeddingParameters params) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
params.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
AsymmetricTextEmbeddingParameters parsedParams = new AsymmetricTextEmbeddingParameters(streamInput);
assertEquals(params, parsedParams);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ public class TextDocsMLInputTest {
@Before
public void setUp() throws Exception {
ModelResultFilter resultFilter = ModelResultFilter.builder().returnBytes(true).returnNumber(true)
.targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build();
.targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build();
MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2"))
.resultFilter(resultFilter).build();
.resultFilter(resultFilter).build();
input = new TextDocsMLInput(algorithm, inputDataset);
}

Expand All @@ -68,8 +68,8 @@ public void parseTextDocsMLInput_NewWay() throws IOException {

private void parseMLInput(String jsonStr, int docSize) throws IOException {
XContentParser parser = XContentType.JSON.xContent()
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();

MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public void setUp() {
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
.embeddingDimension(100)
.passagePrefix("passage: ")
.queryPrefix("query: ")
.build();
function = parser -> {
try {
Expand All @@ -51,7 +53,7 @@ public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
config.toXContent(builder, EMPTY_PARAMS);
String configContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent);
assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}", configContent);
}

@Test
Expand Down Expand Up @@ -83,7 +85,7 @@ public void nullFields_FrameworkType() {

@Test
public void parse() throws IOException {
String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}";
String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}";
TestHelper.testParseFromString(config, content, function);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ public void downloadPrebuiltModelConfig(
case TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD:
configBuilder.modelMaxLength(((Double) configEntry.getValue()).intValue());
break;
case TextEmbeddingModelConfig.QUERY_PREFIX:
configBuilder.queryPrefix(configEntry.getValue().toString());
break;
case TextEmbeddingModelConfig.PASSAGE_PREFIX:
configBuilder.passagePrefix(configEntry.getValue().toString());
break;
default:
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ public abstract class DLModel implements Predictable {
protected Device[] devices;
protected AtomicInteger nextDevice = new AtomicInteger(0);

protected MLModelConfig modelConfig;

@Override
public MLOutput predict(MLInput mlInput, MLModel model) {
throw new IllegalArgumentException("model not deployed");
Expand Down Expand Up @@ -183,6 +185,7 @@ protected void doLoadModel(
IOException,
TranslateException {
devices = Engine.getEngine(engine).getDevices();
this.modelConfig = modelConfig;
for (int i = 0; i < devices.length; i++) {
log.debug("load model {} to device {}: {}", modelId, i, devices[i]);
ZooModel<Input, Output> model;
Expand Down
Loading

0 comments on commit 1cbaed0

Please sign in to comment.