Skip to content

Commit

Permalink
support model embedding types in bedrock and cohere post process func…
Browse files Browse the repository at this point in the history
…tion

Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Dec 9, 2024
1 parent 1d30671 commit 58dd965
Show file tree
Hide file tree
Showing 12 changed files with 530 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
package org.opensearch.ml.common.connector;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.ConnectorPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
import org.opensearch.ml.common.output.model.ModelTensor;

public class MLPostProcessFunction {

Expand All @@ -28,7 +26,7 @@ public class MLPostProcessFunction {

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

private static final Map<String, Function<Object, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();
private static final Map<String, ConnectorPostProcessFunction> POST_PROCESS_FUNCTIONS = new HashMap<>();

static {
EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction();
Expand All @@ -55,7 +53,7 @@ public static String getResponseFilter(String postProcessFunction) {
return JSON_PATH_EXPRESSION.get(postProcessFunction);
}

public static Function<Object, List<ModelTensor>> get(String postProcessFunction) {
public static ConnectorPostProcessFunction get(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import org.opensearch.ml.common.output.model.ModelTensor;

public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction<Map<String, String>> {
public class BedrockBatchJobArnPostProcessFunction implements ConnectorPostProcessFunction {
public static final String JOB_ARN = "jobArn";
public static final String PROCESSED_JOB_ARN = "processedJobArn";

Expand All @@ -28,7 +28,8 @@ public void validate(Object input) {
}

@Override
public List<ModelTensor> process(Map<String, String> jobInfo) {
public List<ModelTensor> process(Object input) {
Map<String, String> jobInfo = (Map<String, String>) input;
List<ModelTensor> modelTensors = new ArrayList<>();
Map<String, String> processedResult = new HashMap<>();
processedResult.putAll(jobInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,122 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

public class BedrockEmbeddingPostProcessFunction extends ConnectorPostProcessFunction<List<Float>> {
import com.google.common.collect.ImmutableMap;

/**
* Post process function for Bedrock embedding responses, currently used by the Bedrock Titan models.
* For a v1 model, the response is a list of float numbers. For a v2 model, the response consists of two parts:
* 1. "embedding", which is a list of float numbers like the v1 response.
* 2. "embeddingByType", which is a map containing the results of all embedding types, keyed by embedding type.
*/
public class BedrockEmbeddingPostProcessFunction implements ConnectorPostProcessFunction {

@Override
public void validate(Object input) {
if (!(input instanceof List)) {
throw new IllegalArgumentException("Post process function input is not a List.");
if (input instanceof List<?>) {
validateEmbeddingList((List<?>) input);
} else if (input instanceof Map) {
for (Map.Entry<String, Object> entry : ((Map<String, Object>) input).entrySet()) {
if (!(entry.getValue() instanceof List)) {
throw new IllegalArgumentException(
String
.format(
Locale.ROOT,
"Model response embedding type %s result is NOT an list type, please check the model response!",
entry.getKey()
)
);
}
validateEmbeddingList((List<?>) entry.getValue());
}
} else {
throw new IllegalArgumentException("Model response is neither a list type nor a map type, please check the model response!");
}
}

List<?> outerList = (List<?>) input;

if (!outerList.isEmpty() && !(((List<?>) input).get(0) instanceof Number)) {
throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values.");
/**
 * Converts the model response into a single-element list of ModelTensor when no result data type is supplied.
 * The response can be a list (a concrete embedding type was selected, or a v1 model was filtered with $.embedding)
 * or a map (embeddings grouped by type). Because no data type is resolved on this path, a map is wrapped as-is into
 * the tensor's dataAsMap field under {@code CommonValue.ML_MAP_RESPONSE_KEY}, and anything else is cast to a list
 * and emitted as a FLOAT32 "sentence_embedding" tensor whose shape is the list size.
 *
 * @param input the model's response: v1 model's embedding part or v2 model's embeddingByType part.
 * @return List of ModelTensor holding either the map of all embedding types or the single embedding.
 */
@Override
public List<ModelTensor> process(Object input) {
    final ModelTensor tensor;
    if (!(input instanceof Map)) {
        // Non-map input is treated as the plain embedding vector.
        List<Float> values = (List<Float>) input;
        tensor = ModelTensor
            .builder()
            .name("sentence_embedding")
            .dataType(MLResultDataType.FLOAT32)
            .shape(new long[] { values.size() })
            .data(values.toArray(new Number[0]))
            .build();
    } else {
        // Map input is passed through untouched inside dataAsMap.
        tensor = ModelTensor
            .builder()
            .name(CommonValue.ML_MAP_RESPONSE_KEY)
            .dataAsMap(ImmutableMap.of(CommonValue.ML_MAP_RESPONSE_KEY, input))
            .build();
    }
    List<ModelTensor> result = new ArrayList<>();
    result.add(tensor);
    return result;
}

/**
* Converts the model response into ModelTensors using the supplied result data type.
* When the input is a list, for example when the response filter points at a concrete embedding type such as
* $.embeddingByType.float, the values are written to the ModelTensor's data field with the given data type, same as before.
* When the input is a map, for example when the response filter points at the whole embeddingByType object, the map is
* written to the ModelTensor's dataAsMap field.
* @param input Model's response or extracted object from the model response by response filter.
* @param mlResultDataType The data type of the model's response.
* @return List of ModelTensor that represent the embedding result including all different embedding types or single embedding type.
*/
@Override
public List<ModelTensor> process(List<Float> embedding) {
public List<ModelTensor> process(Object input, MLResultDataType mlResultDataType) {
List<ModelTensor> modelTensors = new ArrayList<>();
modelTensors
.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[] { embedding.size() })
.data(embedding.toArray(new Number[0]))
.build()
);
if (input instanceof Map) {
modelTensors
.add(
ModelTensor
.builder()
.name(CommonValue.ML_MAP_RESPONSE_KEY)
.dataAsMap(ImmutableMap.of(CommonValue.ML_MAP_RESPONSE_KEY, input))
.build()
);

} else if (input instanceof List) {
List<Number> embedding = (List<Number>) input;
modelTensors
.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(mlResultDataType)
.shape(new long[] { embedding.size() })
.data(embedding.toArray(new Number[0]))
.build()
);
}
return modelTensors;
}

/**
 * Checks that an embedding list is non-empty and that its first element is a Number.
 * Only the first element is inspected; the remaining elements are not checked.
 *
 * @param input the candidate embedding list
 * @throws IllegalArgumentException if the list is empty or its first element is not a Number
 */
private void validateEmbeddingList(List<?> input) {
    boolean startsWithNumber = !input.isEmpty() && input.get(0) instanceof Number;
    if (startsWithNumber) {
        return;
    }
    throw new IllegalArgumentException(
        "Model result is NOT an non-empty List containing Number values, please check the model response!"
    );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

public class CohereRerankPostProcessFunction extends ConnectorPostProcessFunction<List<Map<String, Object>>> {
public class CohereRerankPostProcessFunction implements ConnectorPostProcessFunction {

@Override
public void validate(Object input) {
Expand All @@ -33,7 +33,8 @@ public void validate(Object input) {
}

@Override
public List<ModelTensor> process(List<Map<String, Object>> rerankResults) {
public List<ModelTensor> process(Object input) {
List<Map<String, Object>> rerankResults = (List<Map<String, Object>>) input;
List<ModelTensor> modelTensors = new ArrayList<>();

if (rerankResults.size() > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,37 @@
package org.opensearch.ml.common.connector.functions.postprocess;

import java.util.List;
import java.util.function.Function;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

public abstract class ConnectorPostProcessFunction<T> implements Function<Object, List<ModelTensor>> {
public interface ConnectorPostProcessFunction {

@Override
public List<ModelTensor> apply(Object input) {
default List<ModelTensor> apply(Object input) {
if (input == null) {
throw new IllegalArgumentException("Can't run post process function as model output is null");
}
validate(input);
return process((T) input);
return process(input);
}

public abstract void validate(Object input);
/**
 * Runs the post process function with an explicit result data type: rejects a null model output,
 * validates the input, then delegates to {@code process(input, dataType)}.
 *
 * @param input the model output, or the part of it extracted by the response filter
 * @param dataType the data type handed through to {@code process}
 * @return the tensors produced by {@code process(input, dataType)}
 * @throws IllegalArgumentException if the input is null
 */
default List<ModelTensor> apply(Object input, MLResultDataType dataType) {
    if (input != null) {
        validate(input);
        return process(input, dataType);
    }
    throw new IllegalArgumentException("Can't run post process function as model output is null");
}

void validate(Object input);

public abstract List<ModelTensor> process(T input);
List<ModelTensor> process(Object input);

/**
 * Data-type-aware variant of {@code process}. This default implementation never returns a result:
 * it always throws, so only implementations that override it can be invoked with a result data type.
 *
 * @param input the model output, or the part of it extracted by the response filter
 * @param dataType the requested result data type
 * @return never returns normally in this default implementation
 * @throws IllegalArgumentException always, unless overridden
 */
default List<ModelTensor> process(Object input, MLResultDataType dataType) {
    final String message = "The post process function is not expected to run unless your model is a embedding type supported model"
        + " and the response_filter configuration in connector been set to an embedding type path, please check "
        + "connector.post_process.default.embedding for more information";
    throw new IllegalArgumentException(message);
}
}
Loading

0 comments on commit 58dd965

Please sign in to comment.