Skip to content

Commit

Permalink
Merge two functions together
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Jan 26, 2024
1 parent f90e8f3 commit e1fe572
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.common.connector;

import com.google.common.collect.ImmutableList;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

Expand All @@ -31,37 +32,21 @@ public class MLPostProcessFunction {
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildMultipleResultModelTensor());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildMultipleResultModelTensor());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildMultipleResultModelTensor());
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildSingleResultModelTensor());
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorResult());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorResult());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorResult());
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorResult());
}

public static Function<List<?>, List<ModelTensor>> buildSingleResultModelTensor() {
return embedding -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embedding == null) {
throw new IllegalArgumentException("The embedding is null when using the built-in post-processing function.");
}
modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build()
);
return modelTensors;
};
}

public static Function<List<?>, List<ModelTensor>> buildMultipleResultModelTensor() {
public static Function<List<?>, List<ModelTensor>> buildModelTensorResult() {
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
if (embeddings.get(0) instanceof Number) {
embeddings = ImmutableList.of(embeddings);
}
embeddings.forEach(embedding -> {
List<Number> eachEmbedding = (List<Number>) embedding;
modelTensors.add(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ public void test_getResponseFilter() {

@Test
public void test_buildMultipleResultModelTensorList() {
Assert.assertNotNull(MLPostProcessFunction.buildMultipleResultModelTensor());
Assert.assertNotNull(MLPostProcessFunction.buildModelTensorResult());
List<List<Float>> numbersList = new ArrayList<>();
numbersList.add(Collections.singletonList(1.0f));
Assert.assertNotNull(MLPostProcessFunction.buildMultipleResultModelTensor().apply(numbersList));
Assert.assertNotNull(MLPostProcessFunction.buildModelTensorResult().apply(numbersList));
}

@Test
public void test_buildMultipleResultModelTensorList_exception() {
exceptionRule.expect(IllegalArgumentException.class);
MLPostProcessFunction.buildMultipleResultModelTensor().apply(null);
MLPostProcessFunction.buildModelTensorResult().apply(null);
}

@Test
Expand Down

0 comments on commit e1fe572

Please sign in to comment.