Skip to content

Commit

Permalink
check if response is in type of MLInferenceSearchResponse
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl committed Oct 17, 2024
1 parent 81e6fe8 commit f7682e0
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
import static org.opensearch.ml.processor.MLInferenceIngestProcessor.OVERRIDE;

import java.io.IOException;
import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -156,18 +163,27 @@ public void processResponseAsync(
// if many to one, run rewriteResponseDocuments
if (!oneToOne) {
// use MLInferenceSearchResponseProcessor to allow writing to extension
MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse(
null,
response.getInternalResponse(),
response.getScrollId(),
response.getTotalShards(),
response.getSuccessfulShards(),
response.getSkippedShards(),
response.getSuccessfulShards(),
response.getShardFailures(),
response.getClusters()
);
rewriteResponseDocuments(mLInferenceSearchResponse, responseListener);
// check if the search response is in the type of MLInferenceSearchResponse
// if not, initiate a new one MLInferenceSearchResponse
MLInferenceSearchResponse mlInferenceSearchResponse;

if (response instanceof MLInferenceSearchResponse) {
mlInferenceSearchResponse = (MLInferenceSearchResponse) response;
} else {
mlInferenceSearchResponse = new MLInferenceSearchResponse(
null,
response.getInternalResponse(),
response.getScrollId(),
response.getTotalShards(),
response.getSuccessfulShards(),
response.getSkippedShards(),
response.getSuccessfulShards(),
response.getShardFailures(),
response.getClusters()
);
}

rewriteResponseDocuments(mlInferenceSearchResponse, responseListener);
} else {
// if one to one, make one hit search response and run rewriteResponseDocuments
GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3327,6 +3327,246 @@ public void onFailure(Exception e) {
responseProcessor.processResponseAsync(request, mockResponse, responseContext, listener);
}

/**
 * Tests the processResponseAsync method when the input is a regular SearchResponse.
 *
 * This test verifies that when a plain SearchResponse (not an MLInferenceSearchResponse)
 * is passed to the method, the processor wraps it in a new MLInferenceSearchResponse
 * before rewriting the response documents. Because the processor builds that wrapper
 * itself, the processed response is expected to carry no extension params.
 */
@Test
public void testProcessResponseAsync_WithRegularSearchResponse() {
    String modelInputField = "inputs";
    String originalDocumentField = "text";
    String newDocumentField = "text_embedding";
    String modelOutputField = "response";

    // A plain SearchResponse exercises the "not an MLInferenceSearchResponse" branch
    // of processResponseAsync. (Previously this test passed a pre-built
    // MLInferenceSearchResponse, which duplicated the next test and left this
    // branch uncovered.)
    SearchResponse response = getSearchResponse(5, true, originalDocumentField);

    MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping(
        modelOutputField,
        modelInputField,
        originalDocumentField,
        newDocumentField,
        false,
        false,
        false
    );
    SearchRequest request = getSearchRequest();
    ModelTensor modelTensor = ModelTensor
        .builder()
        .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0)))
        .build();
    ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
    ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

    // Stub the ML client so each prediction returns the fixed tensor output above.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
        actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
        return null;
    }).when(client).execute(any(), any(), any());

    ActionListener<SearchResponse> listener = new ActionListener<>() {
        @Override
        public void onResponse(SearchResponse newSearchResponse) {
            MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) newSearchResponse;
            assertEquals(responseAfterProcessor.getHits().getHits().length, 5);
            assertEquals(responseAfterProcessor.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), 0.0);
            assertEquals(responseAfterProcessor.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), 1.0);
            assertEquals(responseAfterProcessor.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), 2.0);
            assertEquals(responseAfterProcessor.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), 3.0);
            assertEquals(responseAfterProcessor.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), 4.0);
            // The processor created the MLInferenceSearchResponse itself with null params.
            assertNull(responseAfterProcessor.getParams());
        }

        @Override
        public void onFailure(Exception e) {
            throw new RuntimeException(e);
        }
    };

    responseProcessor.processResponseAsync(request, response, responseContext, listener);

}

/**
 * Tests the processResponseAsync method when the input is already an MLInferenceSearchResponse.
 *
 * This test verifies that when an MLInferenceSearchResponse is passed to the method,
 * it is reused as-is and its existing params are carried over to the processed response.
 */
@Test
public void testProcessResponseAsync_WithMLInferenceSearchResponse() {
    String modelInputField = "inputs";
    String originalDocumentField = "text";
    String newDocumentField = "text_embedding";
    String modelOutputField = "response";

    SearchResponse response = getSearchResponse(5, true, originalDocumentField);
    Map<String, Object> params = new HashMap<>();
    params.put("llm_response", "answer");
    MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse(
        params,
        response.getInternalResponse(),
        response.getScrollId(),
        response.getTotalShards(),
        response.getSuccessfulShards(),
        response.getSkippedShards(),
        // tookInMillis slot of the underlying SearchResponse constructor;
        // previously getSuccessfulShards() was passed here by mistake.
        response.getTook().millis(),
        response.getShardFailures(),
        response.getClusters()
    );

    MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping(
        modelOutputField,
        modelInputField,
        originalDocumentField,
        newDocumentField,
        false,
        false,
        false
    );
    SearchRequest request = getSearchRequest();
    ModelTensor modelTensor = ModelTensor
        .builder()
        .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0)))
        .build();
    ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
    ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

    // Stub the ML client so each prediction returns the fixed tensor output above.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
        actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
        return null;
    }).when(client).execute(any(), any(), any());

    ActionListener<SearchResponse> listener = new ActionListener<>() {
        @Override
        public void onResponse(SearchResponse newSearchResponse) {
            MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) newSearchResponse;
            assertEquals(responseAfterProcessor.getHits().getHits().length, 5);
            assertEquals(responseAfterProcessor.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), 0.0);
            assertEquals(responseAfterProcessor.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), 1.0);
            assertEquals(responseAfterProcessor.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), 2.0);
            assertEquals(responseAfterProcessor.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), 3.0);
            assertEquals(responseAfterProcessor.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), 4.0);
            // The original params must survive the processing pass untouched.
            assertEquals(responseAfterProcessor.getParams(), params);
        }

        @Override
        public void onFailure(Exception e) {
            throw new RuntimeException(e);
        }
    };

    responseProcessor.processResponseAsync(request, mLInferenceSearchResponse, responseContext, listener);

}

/**
 * Tests the processResponseAsync method when the input is already an MLInferenceSearchResponse
 * and the processor writes its output into the response extension.
 *
 * This test verifies that the existing params are carried over and that the new
 * model-produced entry ("summary") is added alongside them.
 */
@Test
public void testProcessResponseAsync_WriteExtensionToMLInferenceSearchResponse() {
    String documentField = "text";
    String modelInputField = "context";
    List<Map<String, String>> inputMap = new ArrayList<>();
    Map<String, String> input = new HashMap<>();
    input.put(modelInputField, documentField);
    inputMap.add(input);

    // Writing to "ext.ml_inference.*" routes the model output into the
    // response extension (params) instead of the hit source.
    String newDocumentField = "ext.ml_inference.summary";
    String modelOutputField = "response";
    List<Map<String, String>> outputMap = new ArrayList<>();
    Map<String, String> output = new HashMap<>();
    output.put(newDocumentField, modelOutputField);
    outputMap.add(output);
    Map<String, String> modelConfig = new HashMap<>();
    modelConfig
        .put(
            "prompt",
            "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:"
        );
    MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor(
        "model1",
        inputMap,
        outputMap,
        modelConfig,
        DEFAULT_MAX_PREDICTION_TASKS,
        PROCESSOR_TAG,
        DESCRIPTION,
        false,
        "remote",
        false,
        false,
        false,
        "{ \"parameters\": ${ml_inference.parameters} }",
        client,
        TEST_XCONTENT_REGISTRY_FOR_QUERY,
        false
    );
    SearchResponse response = getSearchResponse(5, true, documentField);
    Map<String, Object> params = new HashMap<>();
    params.put("llm_response", "answer");
    MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse(
        params,
        response.getInternalResponse(),
        response.getScrollId(),
        response.getTotalShards(),
        response.getSuccessfulShards(),
        response.getSkippedShards(),
        // tookInMillis slot of the underlying SearchResponse constructor;
        // previously getSuccessfulShards() was passed here by mistake.
        response.getTook().millis(),
        response.getShardFailures(),
        response.getClusters()
    );

    SearchRequest request = getSearchRequest();
    ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there is 1 value")).build();
    ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
    ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

    // Stub the ML client so the prediction returns the fixed summary string above.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
        actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
        return null;
    }).when(client).execute(any(), any(), any());

    ActionListener<SearchResponse> listener = new ActionListener<>() {
        @Override
        public void onResponse(SearchResponse newSearchResponse) {
            MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) newSearchResponse;
            assertEquals(responseAfterProcessor.getHits().getHits().length, 5);
            // Existing params are preserved and the model output is appended under "summary".
            Map<String, Object> newParams = new HashMap<>();
            newParams.put("llm_response", "answer");
            newParams.put("summary", "there is 1 value");
            assertEquals(responseAfterProcessor.getParams(), newParams);
        }

        @Override
        public void onFailure(Exception e) {
            throw new RuntimeException(e);
        }
    };

    responseProcessor.processResponseAsync(request, mLInferenceSearchResponse, responseContext, listener);

}

private static SearchRequest getSearchRequest() {
QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
Expand Down

0 comments on commit f7682e0

Please sign in to comment.