Skip to content

Commit

Permalink
Add neural search default processor for non OpenAI/Cohere scenario (o…
Browse files Browse the repository at this point in the history
…pensearch-project#1274)

* Fix breaking change caused by opensearch core

Signed-off-by: zane-neo <[email protected]>

* Add neural search default pre/post process function support

Signed-off-by: zane-neo <[email protected]>

* Fix UT failures

Signed-off-by: zane-neo <[email protected]>

* Fix conflicts when backport

Signed-off-by: zane-neo <[email protected]>

* Fix conflict when backport

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo authored and ylwu-amzn committed Oct 4, 2023
1 parent b392fdb commit 10def17
Show file tree
Hide file tree
Showing 14 changed files with 332 additions and 204 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@

package org.opensearch.ml.common.connector;


import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
Expand All @@ -21,17 +32,6 @@
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,64 @@

package org.opensearch.ml.common.connector;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class MLPostProcessFunction {

private static Map<String, String> POST_PROCESS_FUNCTIONS;
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";

public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

private static final Map<String, Function<List<List<Float>>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();


static {
POST_PROCESS_FUNCTIONS = new HashMap<>();
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, "\n def name = \"sentence_embedding\";\n" +
" def dataType = \"FLOAT32\";\n" +
" if (params.embeddings == null || params.embeddings.length == 0) {\n" +
" return null;\n" +
" }\n" +
" def embeddings = params.embeddings;\n" +
" StringBuilder builder = new StringBuilder(\"[\");\n" +
" for (int i=0; i<embeddings.length; i++) {\n" +
" def shape = [embeddings[i].length];\n" +
" def json = \"{\" +\n" +
" \"\\\"name\\\":\\\"\" + name + \"\\\",\" +\n" +
" \"\\\"data_type\\\":\\\"\" + dataType + \"\\\",\" +\n" +
" \"\\\"shape\\\":\" + shape + \",\" +\n" +
" \"\\\"data\\\":\" + embeddings[i] +\n" +
" \"}\";\n" +
" builder.append(json);\n" +
" if (i < embeddings.length - 1) {\n" +
" builder.append(\",\");\n" +
" }\n" +
" }\n" +
" builder.append(\"]\");\n" +
" \n" +
" return builder.toString();\n ");
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
}

POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, "\n def name = \"sentence_embedding\";\n" +
" def dataType = \"FLOAT32\";\n" +
" if (params.data == null || params.data.length == 0) {\n" +
" return null;\n" +
" }\n" +
" def shape = [params.data[0].embedding.length];\n" +
" def json = \"{\" +\n" +
" \"\\\"name\\\":\\\"\" + name + \"\\\",\" +\n" +
" \"\\\"data_type\\\":\\\"\" + dataType + \"\\\",\" +\n" +
" \"\\\"shape\\\":\" + shape + \",\" +\n" +
" \"\\\"data\\\":\" + params.data[0].embedding +\n" +
" \"}\";\n" +
" return json;\n ");
public static Function<List<List<Float>>, List<ModelTensor>> buildModelTensorList() {
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
embeddings.forEach(embedding -> modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build()
));
return modelTensors;
};
}

public static boolean contains(String functionName) {
return POST_PROCESS_FUNCTIONS.containsKey(functionName);
public static String getResponseFilter(String postProcessFunction) {
return JSON_PATH_EXPRESSION.get(postProcessFunction);
}

public static String get(String postProcessFunction) {
public static Function<List<List<Float>>, List<ModelTensor>> get(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
}

public static boolean contains(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,37 @@
package org.opensearch.ml.common.connector;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class MLPreProcessFunction {

private static Map<String, String> PRE_PROCESS_FUNCTIONS;
private static final Map<String, Function<List<String>, Map<String, Object>>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";

public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";

private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("texts", inputs));
}

private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("input", inputs));
}

static {
PRE_PROCESS_FUNCTIONS = new HashMap<>();
//TODO: change to java for openAI, embedding and Titan
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" +
" builder.append(\"[\");\n" +
" for (int i=0; i< params.text_docs.length; i++) {\n" +
" builder.append(\"\\\"\");\n" +
" builder.append(params.text_docs[i]);\n" +
" builder.append(\"\\\"\");\n" +
" if (i < params.text_docs.length - 1) {\n" +
" builder.append(\",\")\n" +
" }\n" +
" }\n" +
" builder.append(\"]\");\n" +
" def parameters = \"{\" +\"\\\"texts\\\":\" + builder + \"}\";\n" +
" return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";");

PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" +
" builder.append(\"\\\"\");\n" +
" builder.append(params.text_docs[0]);\n" +
" builder.append(\"\\\"\");\n" +
" def parameters = \"{\" +\"\\\"input\\\":\" + builder + \"}\";\n" +
" return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";");
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
}

public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static String get(String postProcessFunction) {
public static Function<List<String>, Map<String, Object>> get(String postProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,9 @@
import org.opensearch.ml.common.utils.StringUtils;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;

@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE})
public class RemoteInferenceMLInput extends MLInput {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
public class StringUtils {

public static final Gson gson;

static {
gson = new Gson();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,21 @@
package org.opensearch.ml.common.connector;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING;

public class MLPostProcessFunctionTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Test
public void contains() {
Assert.assertTrue(MLPostProcessFunction.contains(OPENAI_EMBEDDING));
Expand All @@ -23,4 +32,24 @@ public void get() {
Assert.assertNotNull(MLPostProcessFunction.get(OPENAI_EMBEDDING));
Assert.assertNull(MLPostProcessFunction.get("wrong value"));
}

@Test
public void test_getResponseFilter() {
assert null != MLPostProcessFunction.getResponseFilter(OPENAI_EMBEDDING);
assert null == MLPostProcessFunction.getResponseFilter("wrong value");
}

@Test
public void test_buildModelTensorList() {
Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList());
List<List<Float>> numbersList = new ArrayList<>();
numbersList.add(Collections.singletonList(1.0f));
Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList().apply(numbersList));
}

@Test
public void test_buildModelTensorList_exception() {
exceptionRule.expect(IllegalArgumentException.class);
MLPostProcessFunction.buildModelTensorList().apply(null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import com.google.gson.Gson;
import com.google.gson.stream.JsonReader;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.ActionListener;
Expand All @@ -32,6 +31,7 @@
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash;
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.engine.utils.FileUtils.splitFileIntoChunks;
Expand All @@ -48,11 +48,9 @@ public class ModelHelper {
public static final String PYTORCH_ENGINE = "PyTorch";
public static final String ONNX_ENGINE = "OnnxRuntime";
private final MLEngine mlEngine;
private Gson gson;

public ModelHelper(MLEngine mlEngine) {
this.mlEngine = mlEngine;
gson = new Gson();
}

public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelInput> listener) {
Expand Down
Loading

0 comments on commit 10def17

Please sign in to comment.