Skip to content

Commit

Permalink
Add initial search request inference processor (opensearch-project#2616)
Browse files Browse the repository at this point in the history
* add initial search request inference processor

Signed-off-by: Mingshi Liu <[email protected]>

* Add ITs for MLInferenceSearchRequestProcessor

Signed-off-by: Mingshi Liu <[email protected]>

* Skip running OpenAI tests when the API key is not present, and fix a YAML test issue

Signed-off-by: Mingshi Liu <[email protected]>

---------

Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl committed Jul 24, 2024
1 parent f104cca commit cc786f7
Show file tree
Hide file tree
Showing 9 changed files with 2,308 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.processor.MLInferenceIngestProcessor;
import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor;
import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.rest.RestMLCreateConnectorAction;
Expand Down Expand Up @@ -977,7 +978,11 @@ public Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcesso
GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE,
new GenerativeQARequestProcessor.Factory(() -> this.ragSearchPipelineEnabled)
);

requestProcessors
.put(
MLInferenceSearchRequestProcessor.TYPE,
new MLInferenceSearchRequestProcessor.Factory(parameters.client, parameters.namedXContentRegistry)
);
return requestProcessors;
}

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,13 @@ default Object getModelOutputValue(MLOutput mlOutput, String modelOutputFieldNam
return modelTensorOutputMap;
} else {
try {
return JsonPath.parse(modelTensorOutputMap).read(modelOutputFieldName);
Object modelOutputValue = JsonPath.parse(modelTensorOutputMap).read(modelOutputFieldName);
if (modelOutputValue == null) {
throw new IllegalArgumentException(
"model inference output cannot find such json path: " + modelOutputFieldName + " in " + modelTensorOutputMap
);
}
return modelOutputValue;
} catch (Exception e) {
if (ignoreMissing) {
return modelTensorOutputMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.ml.common.spi.MLCommonsExtension;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.engine.tools.MLModelTool;
import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor;
import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.SearchPipelinePlugin;
Expand Down Expand Up @@ -74,10 +75,11 @@ public void testGetSearchExts() {
/**
 * Verifies that the plugin registers both search request processor factories:
 * the generative QA request processor and the newly added ML inference
 * search request processor.
 */
public void testGetRequestProcessors() {
    SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
    Map<String, ?> requestProcessors = plugin.getRequestProcessors(parameters);
    // Two processors are expected now that MLInferenceSearchRequestProcessor is registered
    // alongside the generative QA request processor. (The stale assertEquals(1, ...) diff
    // residue is removed — it contradicted the assertion below.)
    assertEquals(2, requestProcessors.size());
    assertTrue(
        requestProcessors.get(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE) instanceof GenerativeQARequestProcessor.Factory
    );
    assertTrue(requestProcessors.get(MLInferenceSearchRequestProcessor.TYPE) instanceof MLInferenceSearchRequestProcessor.Factory);
}

@Test
Expand Down
Loading

0 comments on commit cc786f7

Please sign in to comment.