Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

asymmetric embeddings #2123

Merged
merged 7 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.model.ModelResultFilter;

import java.io.IOException;
Expand All @@ -22,7 +23,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

/**
* ML input class which supports a list fo text docs.
* ML input class which supports a list of text docs.
* This class can be used for TEXT_EMBEDDING model.
*/
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE})
Expand Down Expand Up @@ -124,6 +125,9 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws
case RESULT_FILTER_FIELD:
resultFilter = ModelResultFilter.parse(parser);
break;
case ML_PARAMETERS_FIELD:
br3no marked this conversation as resolved.
Show resolved Hide resolved
this.parameters = parser.namedObject(MLAlgoParams.class, functionName.name(), null);
break;
default:
parser.skipChildren();
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,25 @@ public class TextEmbeddingModelConfig extends MLModelConfig {
public static final String POOLING_MODE_FIELD = "pooling_mode";
public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";
public static final String QUERY_PREFIX = "query_prefix";
public static final String PASSAGE_PREFIX = "passage_prefix";

private final Integer embeddingDimension;
private final FrameworkType frameworkType;
private final PoolingMode poolingMode;
private final boolean normalizeResult;
private final Integer modelMaxLength;
private final String queryPrefix;
private final String passagePrefix;

/**
 * Convenience constructor for configurations without asymmetric prefixes.
 * Delegates to the full constructor, passing {@code null} for both the query prefix
 * and the passage prefix.
 */
public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
this(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, null, null);
}

@Builder(toBuilder = true)
public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength, String queryPrefix, String passagePrefix) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not add a check function here to ensure at least one of the new parameters is not Null? It looks like this check can avoid problems later in loading the asymmetric model.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TextEmbeddingModelConfig is used for both asymmetric and symmetric models. Symmetric models don't have prefixes. The alternative would be to create a new TextEmbeddingModelConfig class for asymmetric models. I don't think this would be a cleaner solution, tbh.

super(modelType, allConfig);
if (embeddingDimension == null) {
throw new IllegalArgumentException("embedding dimension is null");
Expand All @@ -59,6 +68,8 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr
this.poolingMode = poolingMode;
this.normalizeResult = normalizeResult;
this.modelMaxLength = modelMaxLength;
this.queryPrefix = queryPrefix;
this.passagePrefix = passagePrefix;
}

public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException {
Expand All @@ -69,6 +80,8 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
PoolingMode poolingMode = null;
boolean normalizeResult = false;
Integer modelMaxLength = null;
String queryPrefix = null;
String passagePrefix = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -97,12 +110,18 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
case MODEL_MAX_LENGTH_FIELD:
modelMaxLength = parser.intValue();
break;
case QUERY_PREFIX:
queryPrefix = parser.text();
break;
br3no marked this conversation as resolved.
Show resolved Hide resolved
case PASSAGE_PREFIX:
passagePrefix = parser.text();
break;
br3no marked this conversation as resolved.
Show resolved Hide resolved
default:
parser.skipChildren();
break;
}
}
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength);
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, queryPrefix, passagePrefix);
}

@Override
Expand All @@ -121,6 +140,8 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{
}
normalizeResult = in.readBoolean();
modelMaxLength = in.readOptionalInt();
queryPrefix = in.readOptionalString();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can focus on the feature itself in this PR, but when we release this in a new version we should add a version check. @b4sjoo, can you help with that when the new version is released?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both queryPrefix and passagePrefix are optional, so this shouldn't break compatibility. Are you sure there is a need for a version check?

passagePrefix = in.readOptionalString();
}

@Override
Expand All @@ -136,6 +157,8 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeBoolean(normalizeResult);
out.writeOptionalInt(modelMaxLength);
out.writeOptionalString(queryPrefix);
out.writeOptionalString(passagePrefix);
}

@Override
Expand All @@ -162,6 +185,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (normalizeResult) {
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
}
if (queryPrefix != null) {
builder.field(QUERY_PREFIX, queryPrefix);
}
if (passagePrefix != null) {
builder.field(PASSAGE_PREFIX, passagePrefix);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ public void downloadPrebuiltModelConfig(
case TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD:
configBuilder.modelMaxLength(((Double) configEntry.getValue()).intValue());
break;
case TextEmbeddingModelConfig.QUERY_PREFIX:
configBuilder.queryPrefix(configEntry.getValue().toString());
break;
br3no marked this conversation as resolved.
Show resolved Hide resolved
case TextEmbeddingModelConfig.PASSAGE_PREFIX:
configBuilder.passagePrefix(configEntry.getValue().toString());
break;
default:
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ public abstract class DLModel implements Predictable {
protected Device[] devices;
protected AtomicInteger nextDevice = new AtomicInteger(0);

protected MLModelConfig modelConfig;

@Override
public MLOutput predict(MLInput mlInput, MLModel model) {
throw new IllegalArgumentException("model not deployed");
Expand Down Expand Up @@ -183,6 +185,7 @@ protected void doLoadModel(
IOException,
TranslateException {
devices = Engine.getEngine(engine).getDevices();
this.modelConfig = modelConfig;
for (int i = 0; i < devices.length; i++) {
log.debug("load model {} to device {}: {}", modelId, i, devices[i]);
ZooModel<Input, Output> model;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,35 @@
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.text_embedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.engine.algorithms.text_embedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;

import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;

public abstract class TextEmbeddingModel extends DLModel {

@Override
public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
MLInputDataset inputDataSet = mlInput.getInputDataset();
MLAlgoParams mlParams = mlInput.getParameters();
if (mlParams != null && mlParams instanceof AsymmetricTextEmbeddingParameters) {
addPrefixesToData((AsymmetricTextEmbeddingParameters) mlParams, (TextDocsInputDataSet) inputDataSet);
} else if (modelConfig != null
&& (((TextEmbeddingModelConfig) modelConfig).getPassagePrefix() != null
|| ((TextEmbeddingModelConfig) modelConfig).getQueryPrefix() != null)) {
throw new IllegalArgumentException(
"The embedding model chosen is asymmetric. To use it, you must declare whether the input is a query or a passage."
);
br3no marked this conversation as resolved.
Show resolved Hide resolved
}
List<ModelTensors> tensorOutputs = new ArrayList<>();
Output output;
TextDocsInputDataSet textDocsInput = (TextDocsInputDataSet) inputDataSet;
Expand All @@ -36,6 +50,19 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla
return new ModelTensorOutput(tensorOutputs);
}

private void addPrefixesToData(AsymmetricTextEmbeddingParameters mlParams, TextDocsInputDataSet inputDataSet) {
// Asymmetric embedding models typically work with "mini-prompts" that prime the model to embed a text
// as a query or as a passage. Here we apply the prompt as defined in the model configuration. We default
// to passage embedding.
br3no marked this conversation as resolved.
Show resolved Hide resolved
TextEmbeddingModelConfig modelConfig = (TextEmbeddingModelConfig) this.modelConfig;
String prefix = mlParams.getEmbeddingContentType() == EmbeddingContentType.QUERY
? modelConfig.getQueryPrefix()
: modelConfig.getPassagePrefix();
if (prefix != null) {
inputDataSet.getDocs().replaceAll(doc -> prefix + doc);
br3no marked this conversation as resolved.
Show resolved Hide resolved
}
}

public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig;
String warmUpSentence = "warm up sentence";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.text_embedding;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.Locale;

import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;

import lombok.Builder;
import lombok.Data;

/**
 * Parameters selecting the mode of operation of an asymmetric text embedding algorithm.
 * The algorithm can be used to embed either a query or a passage.
 *
 * The content type is optional: {@link #parse(XContentParser)} leaves it {@code null} when the
 * {@code content_type} field is absent, so every serialization path tolerates a {@code null} value.
 */
@Data
@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING })
public class AsymmetricTextEmbeddingParameters implements MLAlgoParams {

    /** The kind of text being embedded. */
    public enum EmbeddingContentType {
        QUERY,
        PASSAGE
    }

    public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name();
    public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
        MLAlgoParams.class,
        new ParseField(PARSE_FIELD_NAME),
        it -> parse(it)
    );

    /** Name of the XContent field carrying the content type. */
    public static final String EMBEDDING_CONTENT_TYPE_FIELD = "content_type";

    // The type of the content to be embedded; may be null when not specified.
    private EmbeddingContentType embeddingContentType;

    /**
     * @param embeddingContentType whether the input is a query or a passage; may be {@code null}
     */
    @Builder(toBuilder = true)
    public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType) {
        this.embeddingContentType = embeddingContentType;
    }

    /**
     * Reads the parameters from a stream written by {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException {
        // The value is written as an optional string, so it may legitimately be absent.
        String contentType = in.readOptionalString();
        this.embeddingContentType = contentType == null ? null : EmbeddingContentType.valueOf(contentType);
    }

    /**
     * Parses the parameters from an XContent object. Unknown fields are skipped; the content type is
     * matched case-insensitively against {@link EmbeddingContentType}.
     *
     * @param parser a parser positioned on the START_OBJECT token
     * @return the parsed parameters; the content type is {@code null} if the field was absent
     * @throws IOException if the parser fails
     * @throws IllegalArgumentException if the content type is not a known {@link EmbeddingContentType}
     */
    public static MLAlgoParams parse(XContentParser parser) throws IOException {
        EmbeddingContentType embeddingContentType = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();

            switch (fieldName) {
                case EMBEDDING_CONTENT_TYPE_FIELD:
                    String contentType = parser.text();
                    embeddingContentType = EmbeddingContentType.valueOf(contentType.toUpperCase(Locale.ROOT));
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        return new AsymmetricTextEmbeddingParameters(embeddingContentType);
    }

    @Override
    public int getVersion() {
        return 1;
    }

    @Override
    public String getWriteableName() {
        return PARSE_FIELD_NAME;
    }

    /**
     * Writes the content type as an optional string; a {@code null} content type is written as absent.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeOptionalString(embeddingContentType == null ? null : embeddingContentType.name());
    }

    /**
     * Emits an object containing the {@code content_type} field, omitted when the content type is {@code null}.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
        xContentBuilder.startObject();
        if (embeddingContentType != null) {
            xContentBuilder.field(EMBEDDING_CONTENT_TYPE_FIELD, embeddingContentType.name());
        }
        xContentBuilder.endObject();
        return xContentBuilder;
    }

    /**
     * @return the content type, or {@code null} when none was specified
     */
    public EmbeddingContentType getEmbeddingContentType() {
        return embeddingContentType;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.engine.algorithms.text_embedding;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.HUGGINGFACE_TRANSFORMERS;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS;
import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE;
Expand Down Expand Up @@ -43,6 +44,7 @@
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.algorithms.text_embedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.utils.FileUtils;
Expand Down Expand Up @@ -239,6 +241,44 @@ private void initModel_predict_HuggingfaceModel(

}

/**
 * Verifies that a model configured with only a query prefix prepends that prefix to query inputs
 * and leaves passage inputs untouched.
 */
@Test
public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts() throws URISyntaxException {
    Map<String, Object> params = new HashMap<>();
    params.put(MODEL_HELPER, modelHelper);
    params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI()));
    params.put(ML_ENGINE, mlEngine);
    // Only the query prefix is configured; the passage prefix stays null.
    TextEmbeddingModelConfig modelConfig = this.modelConfig.toBuilder().embeddingDimension(768).queryPrefix("query >> ").build();
    MLModel smallModel = model.toBuilder().modelConfig(modelConfig).build();
    textEmbeddingDenseModel.initModel(smallModel, params, encryptor);
    try {
        MLInput mlInputQueries = MLInput
            .builder()
            .algorithm(FunctionName.TEXT_EMBEDDING)
            .inputDataset(
                TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build()
            )
            .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY))
            .build();
        MLInput mlInputPassages = MLInput
            .builder()
            .algorithm(FunctionName.TEXT_EMBEDDING)
            .inputDataset(
                TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build()
            )
            .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.PASSAGE))
            .build();

        textEmbeddingDenseModel.predict(mlInputQueries);
        textEmbeddingDenseModel.predict(mlInputPassages);
        TextDocsInputDataSet queries = (TextDocsInputDataSet) mlInputQueries.getInputDataset();
        TextDocsInputDataSet passages = (TextDocsInputDataSet) mlInputPassages.getInputDataset();

        assertTrue("all docs should start with query prefix", queries.getDocs().stream().allMatch(doc -> doc.startsWith("query >> ")));
        // assertEquals takes (message, expected, actual); the literal is the expected value.
        assertEquals("passage 0 should remain unchanged", "The meaning of life is 42", passages.getDocs().get(0));
        assertEquals("passage 1 should remain unchanged", "I won this year's us open", passages.getDocs().get(1));
    } finally {
        // Release the model even when an assertion above fails.
        textEmbeddingDenseModel.close();
    }
}
br3no marked this conversation as resolved.
Show resolved Hide resolved

@Test
public void initModel_NullModelZipFile() {
exceptionRule.expect(IllegalArgumentException.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
import org.opensearch.ml.engine.algorithms.anomalylocalization.AnomalyLocalizerImpl;
import org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation;
import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator;
import org.opensearch.ml.engine.algorithms.text_embedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
Expand Down Expand Up @@ -843,7 +844,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
AnomalyLocalizationInput.XCONTENT_REGISTRY_ENTRY,
RCFSummarizeParams.XCONTENT_REGISTRY,
LogisticRegressionParams.XCONTENT_REGISTRY,
TextEmbeddingModelConfig.XCONTENT_REGISTRY
TextEmbeddingModelConfig.XCONTENT_REGISTRY,
AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY
);
}

Expand Down
Loading
Loading