Skip to content

Commit

Permalink
Add process function for bedrock (#1554) (#1948)
Browse files Browse the repository at this point in the history
* Add process function for bedrock

Signed-off-by: zane-neo <[email protected]>

* Merge two functions together

Signed-off-by: zane-neo <[email protected]>

* change method name back

Signed-off-by: zane-neo <[email protected]>

* Fix compile issue after rebase

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Fix compile issue after merging methods

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
(cherry picked from commit 33977a1)

Co-authored-by: zane-neo <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and zane-neo authored Jan 29, 2024
1 parent bd8b6ab commit 152c5e2
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.common.connector;

import com.google.common.collect.ImmutableList;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

Expand All @@ -18,38 +19,46 @@ public class MLPostProcessFunction {

public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";

// Maps each post-process function name to the JsonPath expression used to
// extract the embedding(s) from that provider's JSON response.
private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

// Maps each post-process function name to the converter that turns the
// extracted embeddings into ModelTensor objects. The input type is List<?>
// (not List<List<Float>>) because bedrock's "$.embedding" yields a single
// embedding (a List<Number>) while the other providers yield a batch.
private static final Map<String, Function<List<?>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();

static {
    JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
    JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
    JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
    JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
    // All providers share the same tensor-building logic; the function itself
    // normalizes the single-embedding (bedrock) and batch shapes.
    POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
    POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
    POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
    POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorList());
}

public static Function<List<List<Float>>, List<ModelTensor>> buildModelTensorList() {
public static Function<List<?>, List<ModelTensor>> buildModelTensorList() {
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
embeddings.forEach(embedding -> modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build()
));
if (embeddings.get(0) instanceof Number) {
embeddings = ImmutableList.of(embeddings);
}
embeddings.forEach(embedding -> {
List<Number> eachEmbedding = (List<Number>) embedding;
modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{eachEmbedding.size()})
.data(eachEmbedding.toArray(new Number[0]))
.build()
);
});
return modelTensors;
};
}
Expand All @@ -58,7 +67,7 @@ public static String getResponseFilter(String postProcessFunction) {
return JSON_PATH_EXPRESSION.get(postProcessFunction);
}

public static Function<List<List<Float>>, List<ModelTensor>> get(String postProcessFunction) {
public static Function<List<?>, List<ModelTensor>> get(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class MLPreProcessFunction {
private static final Map<String, Function<List<String>, Map<String, Object>>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";

public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";

private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
Expand All @@ -26,17 +26,22 @@ private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPr
return inputs -> Map.of("parameters", Map.of("input", inputs));
}

// Builds the request body for the bedrock embedding API, which accepts a
// single "inputText" string per request. NOTE(review): only the first input
// document is sent; any additional documents are silently dropped here —
// confirm callers invoke this once per document.
private static Function<List<String>, Map<String, Object>> bedrockTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("inputText", inputs.get(0)));
}

// Registers the built-in pre-process functions by connector function name.
// The default embedding input reuses the OpenAI request format.
static {
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess());
}

/**
 * Returns true if a built-in pre-process function is registered under the
 * given name.
 */
public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

/**
 * Returns the built-in pre-process function registered under the given name.
 *
 * @param preProcessFunction the function name, e.g. {@link #TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT}
 * @return the registered function, or null when the name is unknown
 */
public static Function<List<String>, Map<String, Object>> get(String preProcessFunction) {
    return PRE_PROCESS_FUNCTIONS.get(preProcessFunction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import java.util.Collections;
import java.util.List;

import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING;
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING;
import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING;

public class MLPostProcessFunctionTest {
Expand All @@ -29,13 +31,13 @@ public void contains() {

@Test
public void get() {
    // Every registered built-in post-process function must be retrievable,
    // including the bedrock one (its constant is imported but was not
    // asserted here).
    Assert.assertNotNull(MLPostProcessFunction.get(OPENAI_EMBEDDING));
    Assert.assertNotNull(MLPostProcessFunction.get(COHERE_EMBEDDING));
    Assert.assertNotNull(MLPostProcessFunction.get(BEDROCK_EMBEDDING));
    Assert.assertNull(MLPostProcessFunction.get("wrong value"));
}

@Test
public void test_getResponseFilter() {
    // Assert both the pre-existing openai filter and the newly added bedrock
    // filter, rather than swapping one assertion for the other and losing
    // openai coverage.
    assert null != MLPostProcessFunction.getResponseFilter(OPENAI_EMBEDDING);
    assert null != MLPostProcessFunction.getResponseFilter(BEDROCK_EMBEDDING);
    assert null == MLPostProcessFunction.getResponseFilter("wrong value");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ public static ModelTensors processOutput(
// in this case, we can use jsonpath to build a List<List<Float>> result from model response.
if (StringUtils.isBlank(responseFilter))
responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction);
List<List<Float>> vectors = JsonPath.read(modelResponse, responseFilter);
List<?> vectors = JsonPath.read(modelResponse, responseFilter);
List<ModelTensor> processedResponse = executeBuildInPostProcessFunction(
vectors,
MLPostProcessFunction.get(postProcessFunction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ public static Optional<String> executePreprocessFunction(
return Optional.ofNullable(executeScript(scriptService, preProcessFunction, ImmutableMap.of("text_docs", inputSentences)));
}

public static List<ModelTensor> executeBuildInPostProcessFunction(
List<List<Float>> vectors,
Function<List<List<Float>>, List<ModelTensor>> function
) {
public static List<ModelTensor> executeBuildInPostProcessFunction(List<?> vectors, Function<List<?>, List<ModelTensor>> function) {
return function.apply(vectors);
}

Expand Down

0 comments on commit 152c5e2

Please sign in to comment.