-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add support for asymmetric embeddings Signed-off-by: br3no <[email protected]> * fix NPE with sparse models Signed-off-by: br3no <[email protected]> * after review Signed-off-by: br3no <[email protected]> * improving code coverage Signed-off-by: br3no <[email protected]> * improving javadocs Signed-off-by: br3no <[email protected]> * adding unit-test for AsymmetricTextEmbeddingParameters Signed-off-by: br3no <[email protected]> * after latest review Signed-off-by: br3no <[email protected]> --------- Signed-off-by: br3no <[email protected]>
- Loading branch information
Showing
11 changed files
with
615 additions
and
10 deletions.
There are no files selected for viewing
114 changes: 114 additions & 0 deletions
114
common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.dataset; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
|
||
import java.io.IOException; | ||
import java.util.Locale; | ||
|
||
import org.opensearch.core.ParseField; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.annotation.MLAlgoParameter; | ||
import org.opensearch.ml.common.input.parameter.MLAlgoParams; | ||
|
||
import lombok.Builder; | ||
import lombok.Data; | ||
|
||
/** | ||
* This class defines the modes of operation of an asymmetric text embedding model. | ||
* Asymmetric embedding models treat the input text differently, depending on whether it is a | ||
* passage or a query. One example asymmetric model, that requires different prefixes is e5 | ||
* (cf. https://arxiv.org/pdf/2212.03533.pdf). | ||
* <p> | ||
* Use this parameter only if the model is asymmetric and has been registered with the corresponding | ||
* `query_prefix` and `passage_prefix` configuration parameters. | ||
*/ | ||
@Data | ||
@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING }) | ||
public class AsymmetricTextEmbeddingParameters implements MLAlgoParams { | ||
|
||
public enum EmbeddingContentType { | ||
QUERY, | ||
PASSAGE | ||
} | ||
|
||
public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name(); | ||
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( | ||
MLAlgoParams.class, | ||
new ParseField(PARSE_FIELD_NAME), | ||
it -> parse(it) | ||
); | ||
|
||
@Builder(toBuilder = true) | ||
public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType) { | ||
this.embeddingContentType = embeddingContentType; | ||
} | ||
|
||
public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException { | ||
this.embeddingContentType = EmbeddingContentType.valueOf(in.readOptionalString()); | ||
} | ||
|
||
public static MLAlgoParams parse(XContentParser parser) throws IOException { | ||
EmbeddingContentType embeddingContentType = null; | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
|
||
switch (fieldName) { | ||
case EMBEDDING_CONTENT_TYPE_FIELD: | ||
String contentType = parser.text(); | ||
embeddingContentType = EmbeddingContentType.valueOf(contentType.toUpperCase(Locale.ROOT)); | ||
break; | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
return new AsymmetricTextEmbeddingParameters(embeddingContentType); | ||
} | ||
|
||
public static final String EMBEDDING_CONTENT_TYPE_FIELD = "content_type"; | ||
|
||
// The type of the content to be embedded | ||
private EmbeddingContentType embeddingContentType; | ||
|
||
@Override | ||
public int getVersion() { | ||
return 1; | ||
} | ||
|
||
@Override | ||
public String getWriteableName() { | ||
return PARSE_FIELD_NAME; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeOptionalString(embeddingContentType.name()); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { | ||
xContentBuilder.startObject(); | ||
if (embeddingContentType != null) { | ||
xContentBuilder.field(EMBEDDING_CONTENT_TYPE_FIELD, embeddingContentType.name()); | ||
} | ||
xContentBuilder.endObject(); | ||
return xContentBuilder; | ||
} | ||
|
||
public EmbeddingContentType getEmbeddingContentType() { | ||
return embeddingContentType; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
...src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package org.opensearch.ml.common.dataset; | ||
|
||
import org.junit.Before; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
import org.opensearch.common.io.stream.BytesStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.TestHelper; | ||
|
||
import java.io.IOException; | ||
import java.util.function.Function; | ||
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.opensearch.ml.common.TestHelper.contentObjectToString; | ||
import static org.opensearch.ml.common.TestHelper.testParseFromString; | ||
|
||
public class AsymmetricTextEmbeddingParametersTest { | ||
|
||
@Rule | ||
public ExpectedException exceptionRule = ExpectedException.none(); | ||
|
||
AsymmetricTextEmbeddingParameters params; | ||
private Function<XContentParser, AsymmetricTextEmbeddingParameters> function = parser -> { | ||
try { | ||
return (AsymmetricTextEmbeddingParameters) AsymmetricTextEmbeddingParameters.parse(parser); | ||
} catch (IOException e) { | ||
throw new RuntimeException("failed to parse AsymmetricTextEmbeddingParameters", e); | ||
} | ||
}; | ||
|
||
@Before | ||
public void setUp() { | ||
params = AsymmetricTextEmbeddingParameters.builder() | ||
.embeddingContentType(EmbeddingContentType.QUERY) | ||
.build(); | ||
} | ||
|
||
@Test | ||
public void parse_AsymmetricTextEmbeddingParameters() throws IOException { | ||
TestHelper.testParse(params, function); | ||
} | ||
|
||
@Test | ||
public void parse_AsymmetricTextEmbeddingParameters_Passage() throws IOException { | ||
String paramsStr = contentObjectToString(params); | ||
testParseFromString(params, paramsStr.replace("QUERY", "PASSAGE"), function); | ||
} | ||
|
||
@Test | ||
public void parse_AsymmetricTextEmbeddingParameters_Invalid() throws IOException { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU"); | ||
String paramsStr = contentObjectToString(params); | ||
testParseFromString(params, paramsStr.replace("QUERY","fu"), function); | ||
} | ||
|
||
@Test | ||
public void parse_EmptyAsymmetricTextEmbeddingParameters() throws IOException { | ||
TestHelper.testParse(AsymmetricTextEmbeddingParameters.builder().build(), function); | ||
} | ||
|
||
@Test | ||
public void readInputStream_Success() throws IOException { | ||
readInputStream(params); | ||
} | ||
|
||
@Test | ||
public void readInputStream_Success_EmptyParams() throws IOException { | ||
readInputStream(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()); | ||
} | ||
|
||
private void readInputStream(AsymmetricTextEmbeddingParameters params) throws IOException { | ||
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); | ||
params.writeTo(bytesStreamOutput); | ||
|
||
StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); | ||
AsymmetricTextEmbeddingParameters parsedParams = new AsymmetricTextEmbeddingParameters(streamInput); | ||
assertEquals(params, parsedParams); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.