Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl committed Oct 17, 2024
1 parent ccf9c61 commit 81e6fe8
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,7 @@
import static org.opensearch.ml.processor.MLInferenceIngestProcessor.OVERRIDE;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -84,6 +78,8 @@ public class MLInferenceSearchResponseProcessor extends AbstractProcessor implem
// it can be overwritten using max_prediction_tasks when creating processor
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
// allow to write to the extension of the search response, the path to point to search extension
// is prefix with ext.ml_inference
public static final String EXTENSION_PREFIX = "ext.ml_inference";

protected MLInferenceSearchResponseProcessor(
Expand Down Expand Up @@ -804,11 +800,12 @@ public MLInferenceSearchResponseProcessor create(
}
boolean writeToSearchExtension = false;

if (outputMaps != null
&& outputMaps
if (outputMaps != null) {
writeToSearchExtension = outputMaps
.stream()
.anyMatch(outputMap -> outputMap.keySet().stream().anyMatch(key -> key.startsWith(EXTENSION_PREFIX)))) {
writeToSearchExtension = true;
.filter(Objects::nonNull) // To avoid potential NullPointerExceptions from null outputMaps
.flatMap(outputMap -> outputMap.keySet().stream())
.anyMatch(key -> key.startsWith(EXTENSION_PREFIX));
}

if (writeToSearchExtension & oneToOne) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchParseException;
Expand Down Expand Up @@ -87,6 +88,7 @@ public void setup() {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseException() throws Exception {

MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping(
Expand Down Expand Up @@ -115,6 +117,7 @@ public void testProcessResponseException() throws Exception {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseSuccess() throws Exception {
String modelInputField = "inputs";
String originalDocumentField = "text";
Expand Down Expand Up @@ -174,6 +177,7 @@ public void onFailure(Exception e) {
* with one to one prediction, 5 documents in hits are calling 5 prediction tasks
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithCustomPrompt() throws Exception {

String newDocumentField = "context";
Expand Down Expand Up @@ -263,6 +267,7 @@ public void onFailure(Exception e) {
* with many to one prediction, 5 documents in hits are calling 1 prediction tasks
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseManyToOneWithCustomPrompt() throws Exception {

String documentField = "text";
Expand Down Expand Up @@ -360,6 +365,7 @@ public void onFailure(Exception e) {
* with full response path false and no output mapping is provided
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseManyToOneWithCustomPromptFullResponsePathFalse() throws Exception {

String documentField = "text";
Expand Down Expand Up @@ -436,6 +442,7 @@ public void onFailure(Exception e) {
* with full response path true and no output mapping is provided
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseManyToOneWithCustomPromptFullResponsePathTrue() throws Exception {

String documentField = "text";
Expand Down Expand Up @@ -511,6 +518,7 @@ public void onFailure(Exception e) {
* with query extensions
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseSuccessWriteToExt() throws Exception {
String documentField = "text";
String modelInputField = "context";
Expand Down Expand Up @@ -586,6 +594,7 @@ public void onFailure(Exception e) {
* with one to one prediction, 5 documents in hits are calling 5 prediction tasks
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithNoMappings() throws Exception {

MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor(
Expand Down Expand Up @@ -666,6 +675,7 @@ public void onFailure(Exception e) {
* with one to one prediction, 5 documents in hits are calling 5 prediction tasks
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithEmptyMappings() throws Exception {
List<Map<String, String>> outputMap = new ArrayList<>();
List<Map<String, String>> inputMap = new ArrayList<>();
Expand Down Expand Up @@ -747,6 +757,7 @@ public void onFailure(Exception e) {
* with one to one prediction, 5 documents in hits are calling 5 prediction tasks
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithOutputMappings() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -836,6 +847,7 @@ public void onFailure(Exception e) {
* when there is one document, the combinedResponseListener calls onFailure
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerFail() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -897,6 +909,7 @@ public void onFailure(Exception e) {
* when there is one document, the combinedResponseListener calls onFailure
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerException() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -953,6 +966,7 @@ public void onFailure(Exception e) {
* when there is one document and ignoreFailure, should return the original response
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerExceptionIgnoreFailure() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -1009,6 +1023,7 @@ public void onFailure(Exception e) {
* when there is one document and ignoreFailure, should return the original response
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseCreateRewriteResponseListenerExceptionIgnoreFailure() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -1099,6 +1114,7 @@ public void onFailure(Exception e) {
* createRewriteResponseListener should reach on Failure
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseCreateRewriteResponseListenerException() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -1185,6 +1201,7 @@ public void onFailure(Exception e) {
* test throwing OpenSearchStatusException
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOpenSearchStatusException() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -1268,6 +1285,7 @@ public void onFailure(Exception e) {
* test throwing MLResourceNotFoundException
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseMLResourceNotFoundException() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -1353,6 +1371,7 @@ public void onFailure(Exception e) {
* when there is one document, the combinedResponseListener calls onFailure
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithOutputMappingsIgnoreFailure() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -1414,6 +1433,7 @@ public void onFailure(Exception e) {
* when there is one document, the combinedResponseListener calls onFailure
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithOutputMappingsMLTaskResponseExceptionIgnoreFailure() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -1476,6 +1496,7 @@ public void onFailure(Exception e) {
* expect to run one prediction task and the rest 4 predictions tasks are not created
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithOutputMappingsPredictException() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -1532,6 +1553,7 @@ public void onFailure(Exception e) {
* expect to run one prediction task and the rest 4 predictions tasks are not created
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithOutputMappingsPredictFail() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -1594,6 +1616,7 @@ public void onFailure(Exception e) {
* then return original response
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneWithOutputMappingsPredictFailIgnoreFailure() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -1653,6 +1676,7 @@ public void onFailure(Exception e) {
* with one to one prediction, 5 documents in hits are calling 10 prediction tasks
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneTwoRoundsPredictions() throws Exception {

String modelInputField = "inputs";
Expand Down Expand Up @@ -1783,6 +1807,7 @@ public void onFailure(Exception e) {
* expect to throw exception without further processing
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneTwoRoundsPredictionsOneException() throws Exception {

String modelInputField = "inputs";
Expand Down Expand Up @@ -1883,6 +1908,7 @@ public void onFailure(Exception e) {
* expect to return document with second round prediction results
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneTwoRoundsPredictionsOneExceptionIgnoreMissing() throws Exception {

String modelInputField = "inputs";
Expand Down Expand Up @@ -1993,6 +2019,7 @@ public void onFailure(Exception e) {
* expect to return document with second round prediction results
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneTwoRoundsPredictionsOneExceptionIgnoreFailure() throws Exception {

String modelInputField = "inputs";
Expand Down Expand Up @@ -2086,6 +2113,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseNoMappingSuccess() throws Exception {
MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor(
"model1",
Expand Down Expand Up @@ -2163,6 +2191,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseEmptyMappingSuccess() throws Exception {
List<Map<String, String>> inputMap = new ArrayList<>();
Map<String, String> input = new HashMap<>();
Expand Down Expand Up @@ -2243,6 +2272,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseListOfEmbeddingsSuccess() throws Exception {
/**
* sample response before inference
Expand Down Expand Up @@ -2330,6 +2360,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOverrideSameField() throws Exception {
/**
* sample response before inference
Expand Down Expand Up @@ -2416,6 +2447,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOverrideSameFieldFalse() throws Exception {
/**
* sample response before inference
Expand Down Expand Up @@ -2504,6 +2536,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseListOfEmbeddingsMissingOneInputIgnoreMissingSuccess() throws Exception {
/**
* sample response before inference
Expand Down Expand Up @@ -2586,6 +2619,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseListOfEmbeddingsMissingOneInputException() throws Exception {
/**
* sample response before inference
Expand Down Expand Up @@ -2670,6 +2704,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseTwoRoundsOfPredictionSuccess() throws Exception {
String modelInputField = "inputs";
String modelOutputField = "response";
Expand Down Expand Up @@ -2767,6 +2802,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneModelInputMultipleModelOutputs() throws Exception {
// one model input
String modelInputField = "inputs";
Expand Down Expand Up @@ -2853,6 +2889,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponsePredictionException() throws Exception {
MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor(
"model1",
Expand Down Expand Up @@ -2898,6 +2935,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponsePredictionFailed() throws Exception {
MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor(
"model1",
Expand Down Expand Up @@ -2948,6 +2986,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponsePredictionExceptionIgnoreFailure() throws Exception {
MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor(
"model1",
Expand Down Expand Up @@ -2998,6 +3037,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseEmptyHit() throws Exception {
MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor(
"model1",
Expand Down Expand Up @@ -3042,6 +3082,7 @@ public void onFailure(Exception e) {
*
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseHitWithNoSource() throws Exception {
MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor(
"model1",
Expand Down Expand Up @@ -3087,6 +3128,7 @@ public void onFailure(Exception e) {
* Exceptions happen when replaceHits to be one Hit Response
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneMadeOneHitResponseExceptions() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -3154,6 +3196,7 @@ public void onFailure(Exception e) {
* Exceptions happen when replaceHits and ignoreFailure return original response
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneMadeOneHitResponseExceptionsIgnoreFailure() throws Exception {

String newDocumentField = "text_embedding";
Expand Down Expand Up @@ -3220,6 +3263,7 @@ public void onFailure(Exception e) {
* Exceptions happen when replaceHits
* @throws Exception if an error occurs during the test
*/
@Test
public void testProcessResponseOneToOneCombinedHitsExceptions() throws Exception {

String newDocumentField = "text_embedding";
Expand Down
Loading

0 comments on commit 81e6fe8

Please sign in to comment.