Skip to content

Commit

Permalink
[Backport 2.16] Add initial search request inference processor (#2731)
Browse files Browse the repository at this point in the history
  • Loading branch information
opensearch-trigger-bot[bot] authored Jul 24, 2024
1 parent 32a83e3 commit ca6bbe7
Show file tree
Hide file tree
Showing 10 changed files with 2,315 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.processor.MLInferenceIngestProcessor;
import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.rest.RestMLCreateConnectorAction;
import org.opensearch.ml.rest.RestMLCreateControllerAction;
Expand Down Expand Up @@ -977,7 +978,11 @@ public Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcesso
GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE,
new GenerativeQARequestProcessor.Factory(() -> this.ragSearchPipelineEnabled)
);

requestProcessors
.put(
MLInferenceSearchRequestProcessor.TYPE,
new MLInferenceSearchRequestProcessor.Factory(parameters.client, parameters.namedXContentRegistry)
);
return requestProcessors;
}

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,13 @@ default Object getModelOutputValue(MLOutput mlOutput, String modelOutputFieldNam
return modelTensorOutputMap;
} else {
try {
return JsonPath.parse(modelTensorOutputMap).read(modelOutputFieldName);
Object modelOutputValue = JsonPath.parse(modelTensorOutputMap).read(modelOutputFieldName);
if (modelOutputValue == null) {
throw new IllegalArgumentException(
"model inference output cannot find such json path: " + modelOutputFieldName + " in " + modelTensorOutputMap
);
}
return modelOutputValue;
} catch (Exception e) {
if (ignoreMissing) {
return modelTensorOutputMap;
Expand Down Expand Up @@ -313,7 +319,6 @@ default List<String> writeNewDotPathForNestedObject(Object json, String dotPath)
* @return the converted dot path notation string
*/
default String convertToDotPath(String path) {

return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", "");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.ml.common.spi.MLCommonsExtension;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.engine.tools.MLModelTool;
import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
Expand Down Expand Up @@ -73,10 +74,11 @@ public void testGetSearchExts() {
public void testGetRequestProcessors() {
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> requestProcessors = plugin.getRequestProcessors(parameters);
assertEquals(1, requestProcessors.size());
assertEquals(2, requestProcessors.size());
assertTrue(
requestProcessors.get(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE) instanceof GenerativeQARequestProcessor.Factory
);
assertTrue(requestProcessors.get(MLInferenceSearchRequestProcessor.TYPE) instanceof MLInferenceSearchRequestProcessor.Factory);
}

@Test
Expand Down
Loading

0 comments on commit ca6bbe7

Please sign in to comment.