Add tokenizer and sparse encoding #1301

Merged
87 commits
e37956e
add tokenizer and sparse encoding
xinyual Aug 25, 2023
27d8fe0
add tokenizer and sparse encoding
xinyual Aug 25, 2023
0f8cd9c
add tokenizer and sparse encoding
xinyual Aug 25, 2023
af9ba42
add tokenizer and sparse encoding
xinyual Aug 25, 2023
6f49c22
add tokenizer and sparse encoding
xinyual Aug 25, 2023
2b5ea1b
remove special token
xinyual Aug 28, 2023
f832494
add filter
xinyual Aug 28, 2023
f49c471
try empty model
xinyual Aug 28, 2023
2be43a7
remove warm up
xinyual Aug 28, 2023
4b96521
try empty model
xinyual Aug 28, 2023
7b9f97e
add block
xinyual Aug 28, 2023
084c56f
add log
xinyual Aug 28, 2023
c2951cf
add log
xinyual Aug 28, 2023
61720c3
add log
xinyual Aug 28, 2023
99eda1a
remove log
xinyual Aug 28, 2023
5aa5698
remove pt file detect
xinyual Aug 28, 2023
1d4ddba
add log
xinyual Aug 28, 2023
b6614ca
add functionName pipeline
xinyual Aug 28, 2023
e3ca040
remove verify log
xinyual Aug 28, 2023
5eaf588
skip special token in sparse encoding
xinyual Aug 28, 2023
37a878f
skip omit tokenize config
xinyual Aug 29, 2023
4c4ac28
skip omit tokenize config-change warm up logic
xinyual Aug 29, 2023
03ce4af
reArch
xinyual Aug 29, 2023
eac759d
deduplicate
xinyual Aug 29, 2023
c0513b4
omit ml config in sparse encoding
xinyual Aug 29, 2023
0428cf4
add null config in warm up
xinyual Aug 29, 2023
51eef93
fix original test
xinyual Aug 29, 2023
17b755f
add tokenize ut half
xinyual Aug 29, 2023
418aa30
fix sparse encoding bug
xinyual Aug 30, 2023
1fe20e2
add UT for sparse encoding and tokenize
xinyual Aug 30, 2023
df1ca31
remove useless framwork type
xinyual Aug 30, 2023
16b5c89
common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java
xinyual Aug 31, 2023
2a15287
change key for tokenize
xinyual Aug 31, 2023
5ddbb24
reArch DLModel
xinyual Aug 31, 2023
401a7a6
reArch DLModel again
xinyual Aug 31, 2023
2c85cea
response format
xinyual Aug 31, 2023
64a6e6b
tokenize only one output
xinyual Aug 31, 2023
eb7ea9c
clean sparse output
xinyual Aug 31, 2023
65a72ca
clean sparse output
xinyual Aug 31, 2023
f898d6d
change UT number
xinyual Aug 31, 2023
d50d64a
remove useless predict code
xinyual Sep 4, 2023
f6428ee
remove useless part
xinyual Sep 4, 2023
46b0f00
change tokenize way
xinyual Sep 5, 2023
1b57385
reArch add textEmbedding model
xinyual Sep 5, 2023
21bff22
add tokenize logic
xinyual Sep 5, 2023
d5581b6
add abstract
xinyual Sep 5, 2023
ae39b41
clear code
xinyual Sep 5, 2023
2b8529a
fix it class
xinyual Sep 6, 2023
d32e273
fix it class
xinyual Sep 6, 2023
61ac300
add IT file
xinyual Sep 6, 2023
c957d12
reformulate
xinyual Sep 7, 2023
7924cc5
reformulate remote inference
xinyual Sep 7, 2023
71338fa
reformulate remote inference
xinyual Sep 7, 2023
09a6acd
reformulate remote inference json and array
xinyual Sep 7, 2023
84b7006
verify
xinyual Sep 8, 2023
dc30251
undo string utils
xinyual Sep 8, 2023
2ab086e
skip dummy model
xinyual Sep 11, 2023
1e3a2f3
skip dummy model
xinyual Sep 11, 2023
67db622
skip dummy model
xinyual Sep 11, 2023
2748b97
skip dummy model
xinyual Sep 11, 2023
44a5bbd
skip dummy model
xinyual Sep 11, 2023
1e3def0
skip dummy model
xinyual Sep 11, 2023
086f6b0
add inner load Model
xinyual Sep 11, 2023
abc4064
rename variable
xinyual Sep 11, 2023
976d04f
add default for idf
xinyual Sep 12, 2023
89fe98f
add ut for sparse encoding and tokenizer
xinyual Sep 12, 2023
7173590
add close model
xinyual Sep 12, 2023
98a69b2
change mock class
xinyual Sep 12, 2023
2ab6ccf
remove buffer for sparse encoding output
xinyual Sep 12, 2023
0bea7f1
change tokenize model ready logic
xinyual Sep 13, 2023
65a5ed0
rewrite input functionName
xinyual Sep 14, 2023
55e60d9
deduplicate
xinyual Sep 14, 2023
7e9f015
change UT usage
xinyual Sep 14, 2023
86cc578
fix downloadAndSplit test
xinyual Sep 14, 2023
a9cb526
fix Helper test
xinyual Sep 14, 2023
e1c9359
remove meaningless change
xinyual Sep 18, 2023
31222d0
remove complie change
xinyual Sep 19, 2023
1214437
rename
xinyual Sep 21, 2023
fdaba84
fix typo error and simplify wrap code
xinyual Sep 21, 2023
4b96a09
add comment
xinyual Sep 25, 2023
185e95b
using gson and remove useless close logic
xinyual Sep 25, 2023
4f8847c
update comment and import problem
xinyual Sep 26, 2023
7f04f4c
add static idf name
xinyual Sep 26, 2023
4e3dc78
fix format problem
xinyual Sep 26, 2023
a837080
extract an abstract model for sparse and dense sentence transformer t…
xinyual Sep 26, 2023
dca433a
fix typo error
xinyual Sep 26, 2023
73b3f02
remove duplicate tokenizer file, fix import problem and add comment f…
xinyual Sep 27, 2023
@@ -42,6 +42,7 @@ public class CommonValue {
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2;
+ public static final String ML_MAP_RESPONSE_KEY = "response";
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
+ "\": {\n"
@@ -17,6 +17,8 @@ public enum FunctionName {
RCF_SUMMARIZE,
LOGISTIC_REGRESSION,
TEXT_EMBEDDING,
+ SPARSE_ENCODING,
+ SPARSE_TOKENIZE,
METRICS_CORRELATION,
REMOTE;

@@ -33,7 +35,7 @@ public static FunctionName from(String value) {
* @return true for deep learning model.
*/
public static boolean isDLModel(FunctionName functionName) {
- if (functionName == TEXT_EMBEDDING) {
+ if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == SPARSE_TOKENIZE) {
return true;
}
return false;
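The chained comparison in `isDLModel` grows with every new deep learning function name. One possible alternative is a set-membership test. The following is a minimal sketch, not the code in this PR: `FunctionNameSketch` and `DlModelCheck` are hypothetical stand-ins, and the enum lists only the values visible in the hunk above.

```java
import java.util.EnumSet;
import java.util.Set;

// Hypothetical stand-in for FunctionName, trimmed to the values shown in this hunk.
enum FunctionNameSketch {
    RCF_SUMMARIZE,
    LOGISTIC_REGRESSION,
    TEXT_EMBEDDING,
    SPARSE_ENCODING,
    SPARSE_TOKENIZE,
    METRICS_CORRELATION,
    REMOTE
}

final class DlModelCheck {
    // Function names treated as deep learning models after this change.
    private static final Set<FunctionNameSketch> DL_MODELS = EnumSet.of(
        FunctionNameSketch.TEXT_EMBEDDING,
        FunctionNameSketch.SPARSE_ENCODING,
        FunctionNameSketch.SPARSE_TOKENIZE
    );

    // Returns true for deep learning models; null is treated as "not a DL model",
    // matching the behavior of the == comparisons in the diff.
    static boolean isDLModel(FunctionNameSketch functionName) {
        return functionName != null && DL_MODELS.contains(functionName);
    }

    public static void main(String[] args) {
        for (FunctionNameSketch name : FunctionNameSketch.values()) {
            System.out.println(name + " -> " + isDLModel(name));
        }
    }
}
```

Adding a further DL function name would then touch only the `EnumSet.of(...)` call.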
@@ -23,6 +23,7 @@
import java.util.Map;
import java.util.Optional;

+ import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

@Getter
@@ -101,7 +102,7 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
return;
}
if (response instanceof String && isJson((String)response)) {
- Map<String, Object> data = StringUtils.fromJson((String) response, "response");
+ Map<String, Object> data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY);
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
} else {
Map<String, Object> map = new HashMap<>();
@@ -239,7 +239,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
}
}
MLInputDataset inputDataSet = null;
- if (algorithm == FunctionName.TEXT_EMBEDDING) {
+ if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.SPARSE_TOKENIZE) {
ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
inputDataSet = new TextDocsInputDataSet(textDocs, filter);
}
@@ -25,7 +25,7 @@
* ML input class which supports a list fo text docs.
* This class can be used for TEXT_EMBEDDING model.
*/
- @org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING})
+ @org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE})
public class TextDocsMLInput extends MLInput {
public static final String TEXT_DOCS_FIELD = "text_docs";
public static final String RESULT_FILTER_FIELD = "result_filter";
@@ -15,7 +15,6 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
- import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
@@ -104,7 +103,7 @@ public MLRegisterModelInput(FunctionName functionName,
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
}
- if (url != null && modelConfig == null) {
+ if (url != null && modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration.
throw new IllegalArgumentException("model config is null");
}
}
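The relaxed rule above (a model config is mandatory for URL-based registration unless the function is one of the two sparse types) can be isolated in a small predicate. This is a minimal sketch under stated assumptions: `RegisterInputValidation`, its nested `Fn` enum, and `requiresModelConfig` are hypothetical names for illustration and are not part of the PR; only the condition and the exception message come from the diff.

```java
final class RegisterInputValidation {
    // Hypothetical stand-in for FunctionName, limited to the values relevant here.
    enum Fn { TEXT_EMBEDDING, SPARSE_ENCODING, SPARSE_TOKENIZE }

    // Sparse tokenize and sparse encoding models are registered without a model config.
    static boolean requiresModelConfig(Fn functionName) {
        return functionName != Fn.SPARSE_TOKENIZE && functionName != Fn.SPARSE_ENCODING;
    }

    // Mirrors the condition in the hunk: the check applies only when a URL is given.
    static void validate(Fn functionName, String url, Object modelConfig) {
        if (url != null && modelConfig == null && requiresModelConfig(functionName)) {
            throw new IllegalArgumentException("model config is null");
        }
    }

    public static void main(String[] args) {
        // Accepted: sparse tokenize model registered from a URL without a config.
        validate(Fn.SPARSE_TOKENIZE, "https://example.org/model.zip", null);
        try {
            // Rejected: text embedding still needs a config.
            validate(Fn.TEXT_EMBEDDING, "https://example.org/model.zip", null);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Keeping the exemption in one predicate would let `MLRegisterModelInput` and `MLRegisterModelMetaInput` share the same rule instead of repeating the condition and its comment.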
@@ -84,7 +84,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
if (modelContentHashValue == null) {
throw new IllegalArgumentException("model content hash value is null");
}
- if (modelConfig == null) {
+ if (modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration.
throw new IllegalArgumentException("model config is null");
}
if (totalChunks == null) {
@@ -149,21 +149,27 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException {
assertArrayEquals(new long[]{1, 2}, metrics);
}

- @Test
- public void testClassLoader_MLInput() throws IOException {
- assertTrue(MLCommonsClassLoader.canInitMLInput(FunctionName.TEXT_EMBEDDING));
+ private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws IOException {
+ assertTrue(MLCommonsClassLoader.canInitMLInput(functionName));

String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();

- TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(FunctionName.TEXT_EMBEDDING, new Object[]{parser, FunctionName.TEXT_EMBEDDING}, XContentParser.class, FunctionName.class);
+ TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(functionName, new Object[]{parser, functionName}, XContentParser.class, FunctionName.class);
assertNotNull(mlInput);
- assertEquals(FunctionName.TEXT_EMBEDDING, mlInput.getFunctionName());
+ assertEquals(functionName, mlInput.getFunctionName());
assertEquals(2, ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs().size());
}

+ @Test
+ public void testClassLoader_MLInput() throws IOException {
+ testClassLoader_MLInput_DlModel(FunctionName.TEXT_EMBEDDING);
+ testClassLoader_MLInput_DlModel(FunctionName.SPARSE_TOKENIZE);
+ testClassLoader_MLInput_DlModel(FunctionName.SPARSE_ENCODING);
+ }
+
public enum TestEnum {
TEST
}
@@ -10,11 +10,9 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
- import org.opensearch.core.common.Strings;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
- import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
@@ -110,19 +108,19 @@ public void parse_LinearRegression() throws IOException {
});
}

- @Test
- public void parse_TextEmbedding() throws IOException {
+ private void parse_NLPModel(FunctionName functionName) throws IOException {
String sentence = "test sentence";
String column = "column1";
Integer position = 1;
ModelResultFilter resultFilter = ModelResultFilter.builder()
.targetResponse(Arrays.asList(column))
.targetResponsePositions(Arrays.asList(position))
.build();
- TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence))
- .resultFilter(resultFilter).build();
- String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
- testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
+
+ TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).resultFilter(resultFilter).build();
+ String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
+ expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
+ testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
TextDocsInputDataSet parsedInputDataSet = (TextDocsInputDataSet) parsedInput.getInputDataset();
assertEquals(1, parsedInputDataSet.getDocs().size());
@@ -134,19 +132,33 @@ public void parse_TextEmbedding() throws IOException {
}

@Test
- public void parse_TextEmbedding_NullResultFilter() throws IOException {
+ public void parse_NLP_Related() throws IOException {
+ parse_NLPModel(FunctionName.TEXT_EMBEDDING);
+ parse_NLPModel(FunctionName.SPARSE_TOKENIZE);
+ parse_NLPModel(FunctionName.SPARSE_ENCODING);
+ }
+
+ private void parse_NLPModel_NullResultFilter(FunctionName functionName) throws IOException {
String sentence = "test sentence";
TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).build();
- String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"]}";
- testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
+ String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"]}";
+ expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
+ testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
assertEquals(1, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().size());
assertEquals(sentence, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().get(0));
});
}

- private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr,
- Consumer<MLInput> verify) throws IOException {
+
+ @Test
+ public void parse_NLPRelated_NullResultFilter() throws IOException {
+ parse_NLPModel_NullResultFilter(FunctionName.TEXT_EMBEDDING);
+ parse_NLPModel_NullResultFilter(FunctionName.SPARSE_TOKENIZE);
+ parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
+ }
+
+ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
@@ -239,7 +239,7 @@ GET /_plugins/_ml/profile/models/zwla5YUB1qmVrJFlwzXJ
"models": {
"zwla5YUB1qmVrJFlwzXJ": { # model id
"model_state": "LOADED",
- "predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel@1a0b0793",
+ "predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel@1a0b0793",
"target_worker_nodes": [ # plan to deploy model to these nodes
"0TLL4hHxRv6_G3n6y1l0BQ"
],
@@ -191,7 +191,7 @@ public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput re
* @param modelContentHash model content hash value
* @param listener action listener
*/
- public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, ActionListener<Map<String, Object>> listener) {
+ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, FunctionName functionName, ActionListener<Map<String, Object>> listener) {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
@@ -200,7 +200,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
File modelZipFile = new File(modelPath);
log.debug("download model to file {}", modelZipFile.getAbsolutePath());
DownloadUtils.download(url, modelPath, new ProgressBar());
- verifyModelZipFile(modelFormat, modelPath, modelName);
+ verifyModelZipFile(modelFormat, modelPath, modelName, functionName);
String hash = calculateFileHash(modelZipFile);
if (hash.equals(modelContentHash)) {
List<String> chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);
@@ -222,7 +222,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
}
}

- public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName) throws IOException {
+ public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName, FunctionName functionName) throws IOException {
boolean hasPtFile = false;
boolean hasOnnxFile = false;
boolean hasTokenizerFile = false;
@@ -237,7 +237,7 @@ public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePat
}
}
}
- if (!hasPtFile && !hasOnnxFile) {
+ if (!hasPtFile && !hasOnnxFile && functionName != FunctionName.SPARSE_TOKENIZE) { // sparse tokenizer model doesn't need model file.
throw new IllegalArgumentException("Can't find model file");
}
if (!hasTokenizerFile) {
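The effect of the `verifyModelZipFile` change can be illustrated without touching the file system. The sketch below is an assumption-laden simplification, not the PR's implementation: it takes a list of archive entry names instead of reading a zip file, and the `.pt` / `.onnx` suffix checks, the `tokenizer.json` file name, and the tokenizer error message are assumed, since the corresponding lines are elided from the hunk. `ModelZipCheck` and `Fn` are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;

final class ModelZipCheck {
    // Hypothetical stand-in for FunctionName, limited to the values relevant here.
    enum Fn { TEXT_EMBEDDING, SPARSE_ENCODING, SPARSE_TOKENIZE }

    // Validates the entry names of a model archive for the given function.
    static void verify(List<String> entryNames, Fn functionName) {
        boolean hasPtFile = false;
        boolean hasOnnxFile = false;
        boolean hasTokenizerFile = false;
        for (String name : entryNames) {
            if (name.endsWith(".pt")) {            // assumed TorchScript extension
                hasPtFile = true;
            }
            if (name.endsWith(".onnx")) {          // assumed ONNX extension
                hasOnnxFile = true;
            }
            if (name.equals("tokenizer.json")) {   // assumed tokenizer file name
                hasTokenizerFile = true;
            }
        }
        // A sparse tokenizer model ships only a tokenizer, so no model file is required.
        if (!hasPtFile && !hasOnnxFile && functionName != Fn.SPARSE_TOKENIZE) {
            throw new IllegalArgumentException("Can't find model file");
        }
        // The tokenizer file stays mandatory for every function; the message is a placeholder.
        if (!hasTokenizerFile) {
            throw new IllegalArgumentException("tokenizer file not found (placeholder message)");
        }
    }

    public static void main(String[] args) {
        verify(Arrays.asList("tokenizer.json", "idf.json"), Fn.SPARSE_TOKENIZE);
        verify(Arrays.asList("tokenizer.json", "model.pt"), Fn.SPARSE_ENCODING);
        System.out.println("both archives accepted");
    }
}
```

Under these assumptions, a tokenizer-only archive passes for `SPARSE_TOKENIZE` but is still rejected for `SPARSE_ENCODING` and `TEXT_EMBEDDING`.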