Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.16] add initial MLInferenceSearchResponseProcessor #2734

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.processor.MLInferenceIngestProcessor;
import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor;
import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.rest.RestMLCreateConnectorAction;
import org.opensearch.ml.rest.RestMLCreateControllerAction;
Expand Down Expand Up @@ -996,6 +997,12 @@ public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProces
new GenerativeQAResponseProcessor.Factory(this.client, () -> this.ragSearchPipelineEnabled)
);

responseProcessors
.put(
MLInferenceSearchResponseProcessor.TYPE,
new MLInferenceSearchResponseProcessor.Factory(parameters.client, parameters.namedXContentRegistry)
);

return responseProcessors;
}

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,15 @@ default String toString(Object originalFieldValue) {
return StringUtils.toJson(originalFieldValue);
}

default boolean hasField(Object json, String path) {
Object value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path);

if (value != null) {
return true;
}
return false;
}

/**
* Writes a new dot path for a nested object within the given JSON object.
* This method is useful when dealing with arrays or nested objects in the JSON structure.
Expand Down Expand Up @@ -321,5 +330,4 @@ default List<String> writeNewDotPathForNestedObject(Object json, String dotPath)
default String convertToDotPath(String path) {
return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", "");
}

}
53 changes: 53 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/utils/MapUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.utils;

import java.util.HashMap;
import java.util.Map;

public class MapUtils {

/**
* Increments the counter for the given key in the specified version.
* If the key doesn't exist, it initializes the counter to 0.
*
* @param version the version of the counter
* @param key the key for which the counter needs to be incremented
*/
public static void incrementCounter(Map<Integer, Map<String, Integer>> versionedCounters, int version, String key) {
Map<String, Integer> counters = versionedCounters.computeIfAbsent(version, k -> new HashMap<>());
counters.put(key, counters.getOrDefault(key, -1) + 1);
}

/**
* Retrieves the counter value for the given key in the specified version.
* If the key doesn't exist, it returns 0.
*
* @param version the version of the counter
* @param key the key for which the counter needs to be retrieved
* @return the counter value for the given key
*/
public static int getCounter(Map<Integer, Map<String, Integer>> versionedCounters, int version, String key) {
Map<String, Integer> counters = versionedCounters.get(version);
return counters != null ? counters.getOrDefault(key, -1) : 0;
}

/**
* Increments the counter value for the given key in the provided counters map.
* If the key does not exist in the map, it is added with an initial counter value of 0.
*
* @param counters A map that stores integer counters for each integer key.
* @param key The integer key for which the counter needs to be incremented.
*/
public static void incrementCounter(Map<Integer, Integer> counters, int key) {
counters.put(key, counters.getOrDefault(key, 0) + 1);
}

public static int getCounter(Map<Integer, Integer> counters, int key) {
return counters.getOrDefault(key, 0);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.utils;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.profile.SearchProfileShardResults;

public class SearchResponseUtil {
private SearchResponseUtil() {}

/**
* Construct a new {@link SearchResponse} based on an existing one, replacing just the {@link SearchHits}.
* @param newHits new {@link SearchHits}
* @param response the existing search response
* @return a new search response where the {@link SearchHits} has been replaced
*/
public static SearchResponse replaceHits(SearchHits newHits, SearchResponse response) {
SearchResponseSections searchResponseSections;
if (response.getAggregations() == null || response.getAggregations() instanceof InternalAggregations) {
// We either have no aggregations, or we have Writeable InternalAggregations.
// Either way, we can produce a Writeable InternalSearchResponse.
searchResponseSections = new InternalSearchResponse(
newHits,
(InternalAggregations) response.getAggregations(),
response.getSuggest(),
new SearchProfileShardResults(response.getProfileResults()),
response.isTimedOut(),
response.isTerminatedEarly(),
response.getNumReducePhases()
);
} else {
// We have non-Writeable Aggregations, so the whole SearchResponseSections is non-Writeable.
searchResponseSections = new SearchResponseSections(
newHits,
response.getAggregations(),
response.getSuggest(),
response.isTimedOut(),
response.isTerminatedEarly(),
new SearchProfileShardResults(response.getProfileResults()),
response.getNumReducePhases()
);
}

return new SearchResponse(
searchResponseSections,
response.getScrollId(),
response.getTotalShards(),
response.getSuccessfulShards(),
response.getSkippedShards(),
response.getTook().millis(),
response.getShardFailures(),
response.getClusters(),
response.pointInTimeId()
);
}

/**
* Convenience method when only replacing the {@link SearchHit} array within the {@link SearchHits} in a {@link SearchResponse}.
* @param newHits the new array of {@link SearchHit} elements.
* @param response the search response to update
* @return a {@link SearchResponse} where the underlying array of {@link SearchHit} within the {@link SearchHits} has been replaced.
*/
public static SearchResponse replaceHits(SearchHit[] newHits, SearchResponse response) {
if (response.getHits() == null) {
throw new IllegalStateException("Response must have hits");
}
SearchHits searchHits = new SearchHits(
newHits,
response.getHits().getTotalHits(),
response.getHits().getMaxScore(),
response.getHits().getSortFields(),
response.getHits().getCollapseField(),
response.getHits().getCollapseValues()
);
return replaceHits(searchHits, response);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.engine.tools.MLModelTool;
import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor;
import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
Expand Down Expand Up @@ -85,10 +86,11 @@ public void testGetRequestProcessors() {
public void testGetResponseProcessors() {
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters);
assertEquals(1, responseProcessors.size());
assertEquals(2, responseProcessors.size());
assertTrue(
responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory
);
assertTrue(responseProcessors.get(MLInferenceSearchResponseProcessor.TYPE) instanceof MLInferenceSearchResponseProcessor.Factory);
}

@Test
Expand Down
Loading
Loading