From e1fe5723773c947604c86124eeea64b118f30aba Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 26 Jan 2024 14:34:19 +0800 Subject: [PATCH] Merge two functions together Signed-off-by: zane-neo --- .../connector/MLPostProcessFunction.java | 33 +++++-------------- .../connector/MLPostProcessFunctionTest.java | 6 ++-- 2 files changed, 12 insertions(+), 27 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 2eb9d33ad2..c07c83b636 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.common.connector; +import com.google.common.collect.ImmutableList; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; @@ -31,37 +32,21 @@ public class MLPostProcessFunction { JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding"); - POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildMultipleResultModelTensor()); - POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildMultipleResultModelTensor()); - POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildMultipleResultModelTensor()); - POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildSingleResultModelTensor()); + POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorResult()); + POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorResult()); + POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorResult()); + POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorResult()); } - public static Function, List> buildSingleResultModelTensor() { - return embedding -> { - List modelTensors = new ArrayList<>(); - if (embedding == null) { - throw new IllegalArgumentException("The embedding is null when using the built-in post-processing function."); - } - modelTensors.add( - ModelTensor - .builder() - .name("sentence_embedding") - .dataType(MLResultDataType.FLOAT32) - .shape(new long[]{embedding.size()}) - .data(embedding.toArray(new Number[0])) - .build() - ); - return modelTensors; - }; - } - - public static Function, List> buildMultipleResultModelTensor() { + public static Function, List> buildModelTensorResult() { return embeddings -> { List modelTensors = new ArrayList<>(); if (embeddings == null) { throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function."); } + if (embeddings.get(0) instanceof Number) { + embeddings = ImmutableList.of(embeddings); + } embeddings.forEach(embedding -> { List eachEmbedding = (List) embedding; modelTensors.add( diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java index 1fbbaa27ba..b1902832ce 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java @@ -41,16 +41,16 @@ public void test_getResponseFilter() { @Test public void test_buildMultipleResultModelTensorList() { - Assert.assertNotNull(MLPostProcessFunction.buildMultipleResultModelTensor()); + Assert.assertNotNull(MLPostProcessFunction.buildModelTensorResult()); List> numbersList = new ArrayList<>(); numbersList.add(Collections.singletonList(1.0f)); - Assert.assertNotNull(MLPostProcessFunction.buildMultipleResultModelTensor().apply(numbersList)); + Assert.assertNotNull(MLPostProcessFunction.buildModelTensorResult().apply(numbersList)); } @Test public void test_buildMultipleResultModelTensorList_exception() { exceptionRule.expect(IllegalArgumentException.class); - MLPostProcessFunction.buildMultipleResultModelTensor().apply(null); + MLPostProcessFunction.buildModelTensorResult().apply(null); } @Test