diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index abe56cde0e..7981f08175 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -6,15 +6,13 @@ package org.opensearch.ml.common.connector; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.function.Function; import org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction; +import org.opensearch.ml.common.connector.functions.postprocess.ConnectorPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction; -import org.opensearch.ml.common.output.model.ModelTensor; public class MLPostProcessFunction { @@ -28,7 +26,7 @@ public class MLPostProcessFunction { private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); - private static final Map>> POST_PROCESS_FUNCTIONS = new HashMap<>(); + private static final Map POST_PROCESS_FUNCTIONS = new HashMap<>(); static { EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction(); @@ -55,7 +53,7 @@ public static String getResponseFilter(String postProcessFunction) { return JSON_PATH_EXPRESSION.get(postProcessFunction); } - public static Function> get(String postProcessFunction) { + public static ConnectorPostProcessFunction get(String postProcessFunction) { return POST_PROCESS_FUNCTIONS.get(postProcessFunction); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java index e69829855e..b556f6c5aa 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java @@ -12,7 +12,7 @@ import org.opensearch.ml.common.output.model.ModelTensor; -public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction> { +public class BedrockBatchJobArnPostProcessFunction implements ConnectorPostProcessFunction { public static final String JOB_ARN = "jobArn"; public static final String PROCESSED_JOB_ARN = "processedJobArn"; @@ -28,7 +28,8 @@ public void validate(Object input) { } @Override - public List process(Map jobInfo) { + public List process(Object input) { + Map jobInfo = (Map) input; List modelTensors = new ArrayList<>(); Map processedResult = new HashMap<>(); processedResult.putAll(jobInfo); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java index 82823187e8..85f43b65ab 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java @@ -7,38 +7,122 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; -public class BedrockEmbeddingPostProcessFunction extends ConnectorPostProcessFunction> { +import com.google.common.collect.ImmutableMap; + +/** + * Bedrock embedding post process function currently is used by bedrock titan models, for v1 model, + * the model response is a list of float numbers, for v2 model, the model response combined by two parts: + * 1. "embedding" which returns list of float numbers like v1. + * 2. "embeddingByType" is a map contains all embedding type results, with embedding type as the key. + */ +public class BedrockEmbeddingPostProcessFunction implements ConnectorPostProcessFunction { @Override public void validate(Object input) { - if (!(input instanceof List)) { - throw new IllegalArgumentException("Post process function input is not a List."); + if (input instanceof List) { + validateEmbeddingList((List) input); + } else if (input instanceof Map) { + for (Map.Entry entry : ((Map) input).entrySet()) { + if (!(entry.getValue() instanceof List)) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Model response embedding type %s result is NOT an list type, please check the model response!", + entry.getKey() + ) + ); + } + validateEmbeddingList((List) entry.getValue()); + } + } else { + throw new IllegalArgumentException("Model response is neither a list type nor a map type, please check the model response!"); } + } - List outerList = (List) input; - - if (!outerList.isEmpty() && !(((List) input).get(0) instanceof Number)) { - throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values."); + /** + * The response could be list (case1: when specified concrete embedding type or case2: a v1 model specified with $.embedding) + * or map (case3: when specified embedding by type), but since the data type is not resolved, so consider this is case2 or case3. + * @param input the model's response: v1 model's embedding part or v2 model's embeddingByType part. + * @return List of ModelTensor that represent the embedding result including all different embedding types or single embedding type. + */ + @Override + public List process(Object input) { + List modelTensors = new ArrayList<>(); + if (input instanceof Map) { + modelTensors + .add( + ModelTensor + .builder() + .name(CommonValue.ML_MAP_RESPONSE_KEY) + .dataAsMap(ImmutableMap.of(CommonValue.ML_MAP_RESPONSE_KEY, input)) + .build() + ); + } else { + List embedding = (List) input; + modelTensors + .add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[] { embedding.size() }) + .data(embedding.toArray(new Number[0])) + .build() + ); } + return modelTensors; } + /** + * When the response is map, it means user specifies the response filter to a concrete embedding type, e.g.: $.embeddingByType.float + * In this case we need to process the result to ModelTensor's data field as it's same as before. If user specifies the response + * filter to embedding, e.g. $.embedding, then we need to convert the result to ModelTensor's dataAsMap field as the result is a map. + * @param input Model's response or extracted object from the model response by response filter. + * @param mlResultDataType The data type of the model's response. + * @return List of ModelTensor that represent the embedding result including all different embedding types or single embedding type. + */ @Override - public List process(List embedding) { + public List process(Object input, MLResultDataType mlResultDataType) { List modelTensors = new ArrayList<>(); - modelTensors - .add( - ModelTensor - .builder() - .name("sentence_embedding") - .dataType(MLResultDataType.FLOAT32) - .shape(new long[] { embedding.size() }) - .data(embedding.toArray(new Number[0])) - .build() - ); + if (input instanceof Map) { + modelTensors + .add( + ModelTensor + .builder() + .name(CommonValue.ML_MAP_RESPONSE_KEY) + .dataAsMap(ImmutableMap.of(CommonValue.ML_MAP_RESPONSE_KEY, input)) + .build() + ); + + } else if (input instanceof List) { + List embedding = (List) input; + modelTensors + .add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(mlResultDataType) + .shape(new long[] { embedding.size() }) + .data(embedding.toArray(new Number[0])) + .build() + ); + } return modelTensors; } + + private void validateEmbeddingList(List input) { + if (input.isEmpty() || !(input.get(0) instanceof Number)) { + throw new IllegalArgumentException( + "Model result is NOT an non-empty List containing Number values, please check the model response!" + ); + } + } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java index cf93202366..d9adf10b8e 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java @@ -12,7 +12,7 @@ import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; -public class CohereRerankPostProcessFunction extends ConnectorPostProcessFunction>> { +public class CohereRerankPostProcessFunction implements ConnectorPostProcessFunction { @Override public void validate(Object input) { @@ -33,7 +33,8 @@ public void validate(Object input) { } @Override - public List process(List> rerankResults) { + public List process(Object input) { + List> rerankResults = (List>) input; List modelTensors = new ArrayList<>(); if (rerankResults.size() > 0) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java index a5374a42bb..fc68a3b889 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java @@ -6,22 +6,37 @@ package org.opensearch.ml.common.connector.functions.postprocess; import java.util.List; -import java.util.function.Function; +import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; -public abstract class ConnectorPostProcessFunction implements Function> { +public interface ConnectorPostProcessFunction { - @Override - public List apply(Object input) { + default List apply(Object input) { if (input == null) { throw new IllegalArgumentException("Can't run post process function as model output is null"); } validate(input); - return process((T) input); + return process(input); } - public abstract void validate(Object input); + default List apply(Object input, MLResultDataType dataType) { + if (input == null) { + throw new IllegalArgumentException("Can't run post process function as model output is null"); + } + validate(input); + return process(input, dataType); + } + + void validate(Object input); - public abstract List process(T input); + List process(Object input); + + default List process(Object input, MLResultDataType dataType) { + throw new IllegalArgumentException( + "The post process function is not expected to run unless your model is a embedding type supported model" + + " and the response_filter configuration in connector been set to an embedding type path, please check " + + "connector.post_process.default.embedding for more information" + ); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java index e3142b8368..5f9fdb50d1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java @@ -7,48 +7,166 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; -public class EmbeddingPostProcessFunction extends ConnectorPostProcessFunction>> { +import com.google.common.collect.ImmutableMap; +/** + * This is the default embedding post process function, the expected result from the model in two cases: + * 1. A list of list of float for APIs embedding type is not enabled. + * 2. A map of string to list of list of number for APIs that embedding type is enabled. + * An example of enabled embedding type is cohere v2, the embedding API requires embedding_types as a mandatory field in the request body: + * ... + * Currently, OpenAI Cohere and default sagemaker embedding models are using this function. + */ +public class EmbeddingPostProcessFunction implements ConnectorPostProcessFunction { + + /** + * Validate the model's output is a List of following types: + * float + * int8: a signed number with eight bit, the value range is [-128, 127]. + * uint8: an unsigned number with eight bit, the value range is [0, 255]. + * binary: a binary representation of the embedding, the value range is non-deterministic. + * ubinary: a binary representation of the embedding, the value range is non-deterministic. + * @param input The input is the output from the model. + */ @Override public void validate(Object input) { - if (!(input instanceof List)) { - throw new IllegalArgumentException("Post process function input is not a List."); + if (input instanceof List) { + validateEmbeddingList((List) input); + } else if (input instanceof Map) { + for (Map.Entry entry : ((Map) input).entrySet()) { + if (!(entry.getValue() instanceof List)) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Model response embedding type %s result is NOT an list type, please check the model response!", + entry.getKey() + ) + ); + } + validateEmbeddingList((List) entry.getValue()); + } + } else { + throw new IllegalArgumentException("Model response is neither a list type nor a map type, please check the model response!"); } + } - List outerList = (List) input; + /** + * The response could be list (case1: when specified concrete embedding type or case2: a v1 model specified with $.embeddings) + * or map (case3: when specified embeddings in v2 model), but since the data type is not resolved, so consider this is case2 or case3. + * v1 model's embeddings part is a list and v2 is a map. + * @param input the model's response: v1 model's embeddings part or v2 model's embeddings part. + * @return List of ModelTensor that represent the embedding result including all different embedding types or single embedding type. + */ - if (!outerList.isEmpty()) { - if (!(outerList.get(0) instanceof List)) { - throw new IllegalArgumentException("The embedding should be a non-empty List containing List of Float values."); - } - List innerList = (List) outerList.get(0); + @Override + public List process(Object input) { + List modelTensors = new ArrayList<>(); + if (input instanceof Map) { + modelTensors + .add( + ModelTensor + .builder() + .name(CommonValue.ML_MAP_RESPONSE_KEY) + .dataAsMap(ImmutableMap.of(CommonValue.ML_MAP_RESPONSE_KEY, input)) + .build() + ); + } else { + List> embeddings = (List>) input; + embeddings + .forEach( + embedding -> modelTensors + .add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[] { embedding.size() }) + .data(embedding.toArray(new Number[0])) + .build() + ) + ); + } + return modelTensors; + } - if (innerList.isEmpty() || !(innerList.get(0) instanceof Number)) { - throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values."); - } + // List> result + private void validateEmbeddingList(List outerList) { + if (outerList.isEmpty() || !(outerList.get(0) instanceof List)) { + throw new IllegalArgumentException( + "Model result is NOT an non-empty List containing List values, please check the model response!" + ); + } + List innerList = (List) outerList.get(0); + if (innerList.isEmpty() || !(innerList.get(0) instanceof Number)) { + throw new IllegalArgumentException( + "Model result is NOT an non-empty List containing List of Number values, please check the model response!" + ); } } + /** + * As in connector user can configure response filter to extract the result from raw model response, so we need to support different + * cases, take cohere model as an example, the raw result looks like below: + * { + * ...... + * "embeddings": { + * "float": [ + * [ + * -0.007247925, + * -0.041229248, + * -0.023223877 + * ...... + * ] + * ], + * "int8": [ + * 1, + * 2, + * 3 + * ] + * }, + * ...... + * } + * 1. When response filter is set to: $.embeddings.float, then the result is a list of list. + * 2. When response filter is set to: $.embeddings, then the result is a map of embedding type to list of list. + * 3. When response filter is not set which is the default case, and the result is same with case2. + * @param modelOutput the embedding result of embedding type supported models. + * @return List of ModelTensor that represent the embedding result including all different embedding types. + */ @Override - public List process(List> embeddings) { + public List process(Object modelOutput, MLResultDataType mlResultDataType) { List modelTensors = new ArrayList<>(); - embeddings - .forEach( - embedding -> modelTensors + if (modelOutput instanceof Map) { + modelTensors + .add( + ModelTensor + .builder() + .name(CommonValue.ML_MAP_RESPONSE_KEY) + .dataAsMap(ImmutableMap.of(CommonValue.ML_MAP_RESPONSE_KEY, modelOutput)) + .build() + ); + } else if (modelOutput instanceof List) { + for (Object element : (List) modelOutput) { + List singleEmbedding = (List) element; + modelTensors .add( ModelTensor .builder() - .name("sentence_embedding") - .dataType(MLResultDataType.FLOAT32) - .shape(new long[] { embedding.size() }) - .data(embedding.toArray(new Number[0])) + .name("embedding") + .shape(new long[] { singleEmbedding.size() }) + .data(singleEmbedding.toArray(new Number[0])) + .dataType(mlResultDataType) .build() - ) - ); + ); + } + } return modelTensors; } } diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/MLResultDataType.java b/common/src/main/java/org/opensearch/ml/common/output/model/MLResultDataType.java index 660213bc17..2edf9bde48 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/MLResultDataType.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/MLResultDataType.java @@ -15,7 +15,9 @@ public enum MLResultDataType { INT64(Format.INT, 8), BOOLEAN(Format.BOOLEAN, 1), UNKNOWN(Format.UNKNOWN, 0), - STRING(Format.STRING, -1); + STRING(Format.STRING, -1), + BINARY(Format.BINARY, -1), + UBINARY(Format.UBINARY, -1); /** The general data type format categories. */ public enum Format { @@ -24,6 +26,8 @@ public enum Format { INT, BOOLEAN, STRING, + BINARY, + UBINARY, UNKNOWN } @@ -70,4 +74,20 @@ public boolean isBoolean() { public boolean isString() { return format == Format.STRING; } + + /** + * Checks whether it is a binary data type. + * @return true if is a binary type, otherwise false + */ + public boolean isBinary() { + return format == Format.BINARY; + } + + /** + * Checks whether it is a ubinary data type. + * @return true if is a ubinary type, otherwise false + */ + public boolean isUbinary() { + return format == Format.UBINARY; + } } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java index 5a455e0e4b..b966f12bf2 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java @@ -30,14 +30,14 @@ public void setUp() { @Test public void process_WrongInput_NotList() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Post process function input is not a List."); + exceptionRule.expectMessage("Model response is neither a list type nor a map type, please check the model response!"); function.apply("abc"); } @Test public void process_WrongInput_NotNumberList() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("The embedding should be a non-empty List containing Float values."); + exceptionRule.expectMessage("Model result is NOT an non-empty List containing Number values, please check the model response!"); function.apply(Arrays.asList("abc")); } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java index b2abf33216..bb6fcc9083 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java @@ -30,21 +30,22 @@ public void setUp() { @Test public void process_WrongInput_NotList() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Post process function input is not a List."); + exceptionRule.expectMessage("Model response is neither a list type nor a map type, please check the model response!"); function.apply("abc"); } @Test public void process_WrongInput_NotListOfList() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("The embedding should be a non-empty List containing List of Float values."); + exceptionRule.expectMessage("Model result is NOT an non-empty List containing List values, please check the model response!"); function.apply(Arrays.asList("abc")); } @Test public void process_WrongInput_NotListOfNumber() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("The embedding should be a non-empty List containing Float values."); + exceptionRule + .expectMessage("Model result is NOT an non-empty List containing List of Number values, please check the model response!"); function.apply(List.of(Arrays.asList("abc"))); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index f2c93ef5fd..6318f0a313 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.function.Function; @@ -40,6 +41,7 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.model.MLGuard; +import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.script.ScriptService; @@ -225,7 +227,15 @@ public static ModelTensors processOutput( responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction); Object filteredOutput = JsonPath.read(modelResponse, responseFilter); - List processedResponse = MLPostProcessFunction.get(postProcessFunction).apply(filteredOutput); + // For case use specifies the response filter to embedding type, we need to set the embedding type back to processed + // ModelTensor. + MLResultDataType mlResultDataType = parseMLResultDataTypeFromResponseFilter(responseFilter); + List processedResponse; + if (mlResultDataType == null) { + processedResponse = MLPostProcessFunction.get(postProcessFunction).apply(filteredOutput); + } else { + processedResponse = MLPostProcessFunction.get(postProcessFunction).apply(filteredOutput, mlResultDataType); + } return ModelTensors.builder().mlModelTensors(processedResponse).build(); } @@ -244,6 +254,15 @@ public static ModelTensors processOutput( return ModelTensors.builder().mlModelTensors(modelTensors).build(); } + private static MLResultDataType parseMLResultDataTypeFromResponseFilter(String responseFilter) { + for (MLResultDataType type : MLResultDataType.values()) { + if (responseFilter.contains("." + type.name()) || responseFilter.contains("." + type.name().toLowerCase(Locale.ROOT))) { + return type; + } + } + return null; + } + private static String fillProcessFunctionParameter(Map parameters, String processFunction) { if (processFunction != null && processFunction.contains("${parameters.")) { Map tmpParameters = new HashMap<>(); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 286d45d308..4ddc3bf2d4 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -242,6 +242,91 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro } } + public void test_bedrockEmbeddingTypeSupportedModel_withDifferentResponseFilters() throws Exception { + if (tokenNotSet()) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "embedding type supported model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + List input = new ArrayList<>(); + input.add("hello world"); + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + + // when response filter is to: embedding, and the request embedding types have multiple values, the ModelTensor's data is not + // null. + if (testCaseName.equals("response_filter_to_embedding")) { + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); + } else if (testCaseName.equals("response_filter_to_embedding_by_type")) { + // when response filter is embedding_by_type, then the result should be in ModelTensor's dataAsMap field. + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("dataAsMap") instanceof Map); + Map dataAsMap = (Map) ((Map) ((Map) outputList.get(0)).get("dataAsMap")).get("response"); + assertTrue(errorMsg, dataAsMap.containsKey("float") && dataAsMap.containsKey("binary")); + } else if (testCaseName.equals("response_filter_to_embedding_concrete_type") + || testCaseName.equals("response_filter_not_set")) { + // when response filter is to: concrete embedding type or (not set and the request embedding types have only one value), + // the ModelTensor's data is not null. + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + if (testCaseName.equals("response_filter_to_embedding_concrete_type")) { + assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); + } else { + assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + } + } + } + private boolean tokenNotSet() { if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { log.info("#### The AWS credentials are not set. Skipping test. ####"); diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json new file mode 100644 index 0000000000..bee9d60135 --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json @@ -0,0 +1,125 @@ +{ + "response_filter_to_embedding": { + "name": "Amazon Bedrock Connector with response filter to embedding", + "description": "Amazon Bedrock Connector with response filter to embedding", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v2:0", + "response_filter": "$.embedding" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"embeddingTypes\": [\"float\", \"binary\"] }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + }, + "response_filter_to_embedding_by_type": { + "name": "Amazon Bedrock Connector with response filter to embeddingByType", + "description": "Amazon Bedrock Connector with response filter to embeddingByType", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v2:0", + "response_filter": "$.embeddingsByType" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"embeddingTypes\": [\"float\", \"binary\"] }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + }, + "response_filter_to_embedding_concrete_type": { + "name": "Amazon Bedrock Connector with response filter set to concrete embedding type", + "description": "Amazon Bedrock Connector with response filter set to concrete embedding type", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v2:0", + "response_filter": "$.embeddingsByType.float" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"embeddingTypes\": [\"float\", \"binary\"] }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + }, + "response_filter_not_set": { + "name": "Amazon Bedrock Connector with response filter not set for v1 model", + "description": "Amazon Bedrock Connector with response filter not set for v1 model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v1" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\"}", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +}