Skip to content

Commit

Permalink
add initial MLInferenceSearchResponseProcessor (#2688)
Browse files Browse the repository at this point in the history
* add MLInferenceSearchResponseProcessor

Signed-off-by: Mingshi Liu <[email protected]>

* add ITs

Signed-off-by: Mingshi Liu <[email protected]>

* add code coverage

Signed-off-by: Mingshi Liu <[email protected]>

* add many_to_one flag

Signed-off-by: Mingshi Liu <[email protected]>

* avoid import *

Signed-off-by: Mingshi Liu <[email protected]>

* remove extra hits

Signed-off-by: Mingshi Liu <[email protected]>

* spotlessApply

Signed-off-by: Mingshi Liu <[email protected]>

* remove extra brackets

Signed-off-by: Mingshi Liu <[email protected]>

---------

Signed-off-by: Mingshi Liu <[email protected]>
(cherry picked from commit 01084b4)
  • Loading branch information
mingshl authored and github-actions[bot] committed Jul 24, 2024
1 parent ca6bbe7 commit c3a7f28
Show file tree
Hide file tree
Showing 10 changed files with 2,871 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.processor.MLInferenceIngestProcessor;
import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor;
import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.rest.RestMLCreateConnectorAction;
import org.opensearch.ml.rest.RestMLCreateControllerAction;
Expand Down Expand Up @@ -996,6 +997,12 @@ public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProces
new GenerativeQAResponseProcessor.Factory(this.client, () -> this.ragSearchPipelineEnabled)
);

responseProcessors
.put(
MLInferenceSearchResponseProcessor.TYPE,
new MLInferenceSearchResponseProcessor.Factory(parameters.client, parameters.namedXContentRegistry)
);

return responseProcessors;
}

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,15 @@ default String toString(Object originalFieldValue) {
return StringUtils.toJson(originalFieldValue);
}

default boolean hasField(Object json, String path) {
Object value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path);

if (value != null) {
return true;
}
return false;
}

/**
* Writes a new dot path for a nested object within the given JSON object.
* This method is useful when dealing with arrays or nested objects in the JSON structure.
Expand Down Expand Up @@ -321,5 +330,4 @@ default List<String> writeNewDotPathForNestedObject(Object json, String dotPath)
default String convertToDotPath(String path) {
return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", "");
}

}
53 changes: 53 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/utils/MapUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.utils;

import java.util.HashMap;
import java.util.Map;

public class MapUtils {

/**
* Increments the counter for the given key in the specified version.
* If the key doesn't exist, it initializes the counter to 0.
*
* @param version the version of the counter
* @param key the key for which the counter needs to be incremented
*/
public static void incrementCounter(Map<Integer, Map<String, Integer>> versionedCounters, int version, String key) {
Map<String, Integer> counters = versionedCounters.computeIfAbsent(version, k -> new HashMap<>());
counters.put(key, counters.getOrDefault(key, -1) + 1);
}

/**
* Retrieves the counter value for the given key in the specified version.
* If the key doesn't exist, it returns 0.
*
* @param version the version of the counter
* @param key the key for which the counter needs to be retrieved
* @return the counter value for the given key
*/
public static int getCounter(Map<Integer, Map<String, Integer>> versionedCounters, int version, String key) {
Map<String, Integer> counters = versionedCounters.get(version);
return counters != null ? counters.getOrDefault(key, -1) : 0;
}

/**
* Increments the counter value for the given key in the provided counters map.
* If the key does not exist in the map, it is added with an initial counter value of 0.
*
* @param counters A map that stores integer counters for each integer key.
* @param key The integer key for which the counter needs to be incremented.
*/
public static void incrementCounter(Map<Integer, Integer> counters, int key) {
counters.put(key, counters.getOrDefault(key, 0) + 1);
}

public static int getCounter(Map<Integer, Integer> counters, int key) {
return counters.getOrDefault(key, 0);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.utils;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.profile.SearchProfileShardResults;

public class SearchResponseUtil {
private SearchResponseUtil() {}

/**
* Construct a new {@link SearchResponse} based on an existing one, replacing just the {@link SearchHits}.
* @param newHits new {@link SearchHits}
* @param response the existing search response
* @return a new search response where the {@link SearchHits} has been replaced
*/
public static SearchResponse replaceHits(SearchHits newHits, SearchResponse response) {
SearchResponseSections searchResponseSections;
if (response.getAggregations() == null || response.getAggregations() instanceof InternalAggregations) {
// We either have no aggregations, or we have Writeable InternalAggregations.
// Either way, we can produce a Writeable InternalSearchResponse.
searchResponseSections = new InternalSearchResponse(
newHits,
(InternalAggregations) response.getAggregations(),
response.getSuggest(),
new SearchProfileShardResults(response.getProfileResults()),
response.isTimedOut(),
response.isTerminatedEarly(),
response.getNumReducePhases()
);
} else {
// We have non-Writeable Aggregations, so the whole SearchResponseSections is non-Writeable.
searchResponseSections = new SearchResponseSections(
newHits,
response.getAggregations(),
response.getSuggest(),
response.isTimedOut(),
response.isTerminatedEarly(),
new SearchProfileShardResults(response.getProfileResults()),
response.getNumReducePhases()
);
}

return new SearchResponse(
searchResponseSections,
response.getScrollId(),
response.getTotalShards(),
response.getSuccessfulShards(),
response.getSkippedShards(),
response.getTook().millis(),
response.getShardFailures(),
response.getClusters(),
response.pointInTimeId()
);
}

/**
* Convenience method when only replacing the {@link SearchHit} array within the {@link SearchHits} in a {@link SearchResponse}.
* @param newHits the new array of {@link SearchHit} elements.
* @param response the search response to update
* @return a {@link SearchResponse} where the underlying array of {@link SearchHit} within the {@link SearchHits} has been replaced.
*/
public static SearchResponse replaceHits(SearchHit[] newHits, SearchResponse response) {
if (response.getHits() == null) {
throw new IllegalStateException("Response must have hits");
}
SearchHits searchHits = new SearchHits(
newHits,
response.getHits().getTotalHits(),
response.getHits().getMaxScore(),
response.getHits().getSortFields(),
response.getHits().getCollapseField(),
response.getHits().getCollapseValues()
);
return replaceHits(searchHits, response);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.engine.tools.MLModelTool;
import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor;
import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
Expand Down Expand Up @@ -85,10 +86,11 @@ public void testGetRequestProcessors() {
public void testGetResponseProcessors() {
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters);
assertEquals(1, responseProcessors.size());
assertEquals(2, responseProcessors.size());
assertTrue(
responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory
);
assertTrue(responseProcessors.get(MLInferenceSearchResponseProcessor.TYPE) instanceof MLInferenceSearchResponseProcessor.Factory);
}

@Test
Expand Down
Loading

0 comments on commit c3a7f28

Please sign in to comment.