From 310d023c519959c05767a9d94e5683df105c92b9 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Mon, 22 Jul 2024 14:56:35 -0700 Subject: [PATCH] Add Batch Prediction Mode in the Connector Framework for batch inference (#2661) (#2701) * add batch predict job actiontype in connector * remove async and streaming mode temporarily * rename predict mode to action type * use method name in the url path for action type * add stats for actionType and more UTs * add bwx for actiontype * address more comments --------- Signed-off-by: Xun Zhang --- .../org/opensearch/ml/common/CommonValue.java | 1 + .../ml/common/connector/ConnectorAction.java | 32 ++++++++++- .../remote/RemoteInferenceInputDataSet.java | 31 +++++++++- .../opensearch/ml/common/input/MLInput.java | 14 +++++ .../input/remote/RemoteInferenceMLInput.java | 11 +++- .../common/connector/ConnectorActionTest.java | 15 +++++ .../RemoteInferenceInputDataSetTest.java | 21 +++++++ .../ml/common/input/MLInputTest.java | 57 ++++++++++++++++++- .../remote/RemoteInferenceMLInputTest.java | 3 +- .../memory/index/InteractionsIndexTests.java | 1 - .../engine/algorithms/remote/RemoteModel.java | 9 ++- .../MLSdkAsyncHttpResponseHandlerTest.java | 2 - .../ml/rest/RestMLPredictionAction.java | 14 ++++- .../ml/settings/MLCommonsSettings.java | 1 + .../org/opensearch/ml/stats/ActionName.java | 3 +- .../ml/task/MLPredictTaskRunner.java | 43 ++++++++------ .../opensearch/ml/utils/RestActionUtils.java | 21 +++++++ .../ml/rest/RestMLPredictionActionTests.java | 29 ++++++++++ .../org/opensearch/ml/utils/TestHelper.java | 37 ++++++++++++ 19 files changed, 314 insertions(+), 31 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 192b7a9737..93564a7d4b 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -538,4 +538,5 @@ public class CommonValue { public static final Version VERSION_2_12_0 = Version.fromString("2.12.0"); public static final Version VERSION_2_13_0 = Version.fromString("2.13.0"); public static final Version VERSION_2_14_0 = Version.fromString("2.14.0"); + public static final Version VERSION_2_16_0 = Version.fromString("2.16.0"); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index e424914b4f..9be290d126 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -14,10 +14,13 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; import java.io.IOException; +import java.util.HashSet; import java.util.Locale; import java.util.Map; +import java.util.Set; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -183,6 +186,33 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { public enum ActionType { PREDICT, - EXECUTE + EXECUTE, + BATCH_PREDICT; + + public static ActionType from(String value) { + try { + return ActionType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong Action Type of " + value); + } + } + + private static final HashSet MODEL_SUPPORT_ACTIONS = new HashSet<>(Set.of( + PREDICT, + BATCH_PREDICT + )); + + public static boolean isValidActionInModelPrediction(ActionType actionType) { + return MODEL_SUPPORT_ACTIONS.contains(actionType); + } + + public static boolean isValidAction(String action) { + try { + ActionType.valueOf(action.toUpperCase()); + return true; + } catch (IllegalArgumentException e) { + return false; + } + } } } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java index 85a1b72022..3023d5c3fc 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java @@ -10,8 +10,11 @@ import lombok.Builder; import lombok.Getter; import lombok.Setter; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.MLInputDataset; @@ -19,32 +22,56 @@ @Getter @InputDataSet(MLInputDataType.REMOTE) public class RemoteInferenceInputDataSet extends MLInputDataset { - + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0; @Setter private Map parameters; + @Setter + private ActionType actionType; @Builder(toBuilder = true) - public RemoteInferenceInputDataSet(Map parameters) { + public RemoteInferenceInputDataSet(Map parameters, ActionType actionType) { super(MLInputDataType.REMOTE); this.parameters = parameters; + this.actionType = actionType; + } + + public RemoteInferenceInputDataSet(Map parameters) { + this(parameters, null); } public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException { super(MLInputDataType.REMOTE); + Version streamInputVersion = streamInput.getVersion(); if (streamInput.readBoolean()) { parameters = streamInput.readMap(s -> s.readString(), s-> s.readString()); } + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) { + if (streamInput.readBoolean()) { + actionType = streamInput.readEnum(ActionType.class); + } else { + this.actionType = null; + } + } } @Override public void writeTo(StreamOutput streamOutput) throws IOException { super.writeTo(streamOutput); + Version streamOutputVersion = streamOutput.getVersion(); if (parameters != null) { streamOutput.writeBoolean(true); streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); } else { streamOutput.writeBoolean(false); } + if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) { + if (actionType != null) { + streamOutput.writeBoolean(true); + streamOutput.writeEnum(actionType); + } else { + streamOutput.writeBoolean(false); + } + } } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java index ed28cdfc1f..4bf166f9b6 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java @@ -13,6 +13,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DefaultDataFrame; @@ -35,6 +36,7 @@ import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.input.remote.RemoteInferenceMLInput.ACTION_TYPE_FIELD; /** * ML input data: algorithm name, parameters and input data set. @@ -196,6 +198,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) this.inputDataset; Map parameters = remoteInferenceInputDataSet.getParameters(); builder.field(PARAMETERS_FIELD, parameters); + builder.field(ACTION_TYPE_FIELD, remoteInferenceInputDataSet.getActionType()); break; default: break; @@ -206,6 +209,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + public static MLInput parse(XContentParser parser, String inputAlgoName, ActionType actionType) throws IOException { + MLInput mlInput = parse(parser, inputAlgoName); + if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet)mlInput.getInputDataset(); + if (remoteInferenceInputDataSet.getActionType() == null) { + remoteInferenceInputDataSet.setActionType(actionType); + } + } + return mlInput; + } + public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException { String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT); FunctionName algorithm = FunctionName.from(algorithmName); diff --git a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java index da4a9ad73d..cd45cb19cb 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java @@ -9,6 +9,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; @@ -21,6 +22,7 @@ @org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE}) public class RemoteInferenceMLInput extends MLInput { public static final String PARAMETERS_FIELD = "parameters"; + public static final String ACTION_TYPE_FIELD = "action_type"; public RemoteInferenceMLInput(StreamInput in) throws IOException { super(in); @@ -34,6 +36,8 @@ public void writeTo(StreamOutput out) throws IOException { public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) throws IOException { super(); this.algorithm = functionName; + Map parameters = null; + ActionType actionType = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -41,14 +45,17 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) switch (fieldName) { case PARAMETERS_FIELD: - Map parameters = StringUtils.getParameterMap(parser.map()); - inputDataset = new RemoteInferenceInputDataSet(parameters); + parameters = StringUtils.getParameterMap(parser.map()); + break; + case ACTION_TYPE_FIELD: + actionType = ActionType.from(parser.text()); break; default: parser.skipChildren(); break; } } + inputDataset = new RemoteInferenceInputDataSet(parameters, actionType); } } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java index e05dd76dcd..2a7927a269 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java @@ -24,6 +24,8 @@ import java.util.HashMap; import java.util.Map; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction; + public class ConnectorActionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -140,4 +142,17 @@ public void parse() throws IOException { Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); } + + @Test + public void test_wrongActionType() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Wrong Action Type"); + ConnectorAction.ActionType.from("badAction"); + } + + @Test + public void test_invalidActionInModelPrediction() { + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute"); + Assert.assertEquals(isValidActionInModelPrediction(actionType), false); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java index 55381895ee..22a549a1d1 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java @@ -9,6 +9,7 @@ import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataset.MLInputDataset; public class RemoteInferenceInputDataSetTest { @@ -44,4 +45,24 @@ public void writeTo() throws IOException { Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1")); Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2")); } + + @Test + public void writeTo_withActionType() throws IOException { + Map parameters = new HashMap<>(); + parameters.put("key1", "test value1"); + parameters.put("key2", "test value2"); + ActionType actionType = ActionType.from("predict"); + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).actionType(actionType).build(); + + BytesStreamOutput output = new BytesStreamOutput(); + inputDataSet.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + + RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput); + Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType()); + Assert.assertEquals(2, inputDataSet2.getParameters().size()); + Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1")); + Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2")); + Assert.assertEquals("PREDICT", inputDataSet2.getActionType().toString()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java index d08c678634..a63f815232 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java @@ -27,12 +27,14 @@ import org.opensearch.ml.common.dataframe.DefaultDataFrame; import org.opensearch.ml.common.dataframe.DoubleValue; import org.opensearch.ml.common.dataframe.Row; -import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.nlp.TextSimilarityMLInput; import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -44,6 +46,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; @@ -168,6 +171,40 @@ public void parse_NLPRelated_NullResultFilter() throws IOException { parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING); } + @Test + public void parse_Remote_Model() throws IOException { + Map parameters = Map.of("TransformJobName", "new name"); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder() + .parameters(parameters) + .actionType(ConnectorAction.ActionType.PREDICT) + .build(); + + String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"PREDICT\"}"; + + testParse(FunctionName.REMOTE, remoteInferenceInputDataSet, expectedInputStr, parsedInput -> { + assertNotNull(parsedInput.getInputDataset()); + RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset(); + assertEquals(ConnectorAction.ActionType.PREDICT, parsedInputDataSet.getActionType()); + }); + } + + @Test + public void parse_Remote_Model_With_ActionType() throws IOException { + Map parameters = Map.of("TransformJobName", "new name"); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder() + .parameters(parameters) + .actionType(ConnectorAction.ActionType.BATCH_PREDICT) + .build(); + + String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH_PREDICT\"}"; + + testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH_PREDICT, expectedInputStr, parsedInput -> { + assertNotNull(parsedInput.getInputDataset()); + RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset(); + assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, parsedInputDataSet.getActionType()); + }); + } + private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer verify) throws IOException { MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build(); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); @@ -186,6 +223,24 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri verify.accept(parsedInput); } + private void testParseWithActionType(FunctionName algorithm, MLInputDataset inputDataset, ConnectorAction.ActionType actionType, String expectedInputStr, Consumer verify) throws IOException { + MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals(expectedInputStr, jsonStr); + + XContentParser parser = XContentType.JSON.xContent() + .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + MLInput parsedInput = MLInput.parse(parser, algorithm.name(), actionType); + assertEquals(input.getFunctionName(), parsedInput.getFunctionName()); + assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType()); + verify.accept(parsedInput); + } + @Test public void readInputStream_Success() throws IOException { readInputStream(input, parsedInput -> { diff --git a/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java index a01a955e7a..759bf154b3 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java @@ -39,10 +39,11 @@ public void constructor_stream() throws IOException { RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset(); Assert.assertEquals(1, inputDataSet.getParameters().size()); Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt")); + Assert.assertEquals("BATCH_PREDICT", inputDataSet.getActionType().toString()); } private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException { - String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }"; + String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"batch_predict\" }"; XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), null, jsonStr); parser.nextToken(); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 4da9f9d68e..042a4a3a91 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -750,7 +750,6 @@ public void testGetSg_NoIndex_ThenFail() { interactionsIndex.getInteraction("iid", getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); - System.out.println(argCaptor.getValue().getMessage()); assert (argCaptor .getValue() .getMessage() diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 0f208adb7d..f43f0ca0c3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -18,6 +18,8 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.model.MLGuard; @@ -70,7 +72,12 @@ public void asyncPredict(MLInput mlInput, ActionListener actionL return; } try { - connectorExecutor.executeAction(PREDICT.name(), mlInput, actionListener); + ActionType actionType = null; + if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + actionType = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getActionType(); + } + actionType = actionType == null ? ActionType.PREDICT : actionType; + connectorExecutor.executeAction(actionType.toString(), mlInput, actionListener); } catch (RuntimeException e) { log.error("Failed to call remote model.", e); actionListener.onFailure(e); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java index f6c9b76071..44d3f104cb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java @@ -326,7 +326,6 @@ public void test_onComplete_error_http_status() { ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(captor.capture()); assert captor.getValue() instanceof OpenSearchStatusException; - System.out.println(captor.getValue().getMessage()); assert captor.getValue().getMessage().contains("runtime error"); } @@ -350,7 +349,6 @@ public void test_onComplete_throttle_error_headers() { ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(captor.capture()); assert captor.getValue() instanceof OpenSearchStatusException; - System.out.println(captor.getValue().getMessage()); assert captor.getValue().getMessage().contains(REMOTE_SERVICE_ERROR); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index 5af116eb6f..72b841eb7b 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getActionTypeFromRestRequest; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; import java.io.IOException; @@ -25,6 +26,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; @@ -69,7 +71,11 @@ public List routes() { RestRequest.Method.POST, String.format(Locale.ROOT, "%s/_predict/{%s}/{%s}", ML_BASE_URI, PARAMETER_ALGORITHM, PARAMETER_MODEL_ID) ), - new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/{%s}/_predict", ML_BASE_URI, PARAMETER_MODEL_ID)) + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/{%s}/_predict", ML_BASE_URI, PARAMETER_MODEL_ID)), + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/models/{%s}/_batch_predict", ML_BASE_URI, PARAMETER_MODEL_ID) + ) ); } @@ -120,14 +126,18 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client */ @VisibleForTesting MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException { + ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request)); if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); + } else if (!ActionType.isValidActionInModelPrediction(actionType)) { + throw new IllegalArgumentException("Wrong action type in the rest request path!"); } + XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLInput mlInput = MLInput.parse(parser, algorithm); + MLInput mlInput = MLInput.parse(parser, algorithm, actionType); return new MLPredictionTaskRequest(modelId, mlInput, null); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 1db559ed80..6daffd30fd 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -142,6 +142,7 @@ private MLCommonsSettings() {} ImmutableList .of( "^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://api\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$", "^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", diff --git a/plugin/src/main/java/org/opensearch/ml/stats/ActionName.java b/plugin/src/main/java/org/opensearch/ml/stats/ActionName.java index 91a091364f..8525793a31 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/ActionName.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/ActionName.java @@ -12,7 +12,8 @@ public enum ActionName { EXECUTE, REGISTER, DEPLOY, - UNDEPLOY; + UNDEPLOY, + BATCH_PREDICT; public static ActionName from(String value) { try { diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index b341f4c9f5..72c43bd58f 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -47,8 +47,10 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; @@ -276,13 +278,12 @@ private String getPredictThreadPool(FunctionName functionName) { private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache + ActionName actionName = getActionNameFromInput(mlInput); mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); - mlStats - .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) - .increment(); + mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), actionName, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); if (modelId != null) { - mlStats.createModelCounterStatIfAbsent(modelId, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); + mlStats.createModelCounterStatIfAbsent(modelId, actionName, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); } mlTask.setState(MLTaskState.RUNNING); mlTaskManager.add(mlTask); @@ -305,7 +306,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe .workerNodes(Arrays.asList(clusterService.localNode().getId())) .build(); mlModelManager.deployModel(modelId, null, functionName, false, true, mlDeployTask, ActionListener.wrap(s -> { - runPredict(modelId, mlTask, mlInput, functionName, internalListener); + runPredict(modelId, mlTask, mlInput, functionName, actionName, internalListener); }, e -> { log.error("Failed to auto deploy model " + modelId, e); internalListener.onFailure(e); @@ -313,7 +314,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe return; } - runPredict(modelId, mlTask, mlInput, functionName, internalListener); + runPredict(modelId, mlTask, mlInput, functionName, actionName, internalListener); } private void runPredict( @@ -321,6 +322,7 @@ private void runPredict( MLTask mlTask, MLInput mlInput, FunctionName algorithm, + ActionName actionName, ActionListener internalListener ) { // run predict @@ -340,7 +342,7 @@ private void runPredict( handleAsyncMLTaskComplete(mlTask); mlModelManager.trackPredictDuration(modelId, startTime); internalListener.onResponse(output); - }, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId)); + }, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName)); predictor.asyncPredict(mlInput, trackPredictDurationListener); } else { MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput)); @@ -357,7 +359,7 @@ private void runPredict( return; } catch (Exception e) { log.error("Failed to predict model " + modelId, e); - handlePredictFailure(mlTask, internalListener, e, false, modelId); + handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName); return; } } else if (FunctionName.needDeployFirst(algorithm)) { @@ -388,7 +390,7 @@ private void runPredict( OpenSearchException e = new OpenSearchException( "User: " + requestUser.getName() + " does not have permissions to run predict by model: " + modelId ); - handlePredictFailure(mlTask, internalListener, e, false, modelId); + handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName); return; } // run predict @@ -413,7 +415,7 @@ private void runPredict( }, e -> { log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e); - handlePredictFailure(mlTask, internalListener, e, true, modelId); + handlePredictFailure(mlTask, internalListener, e, true, modelId, actionName); }); GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId()); client @@ -426,12 +428,12 @@ private void runPredict( ); } catch (Exception e) { log.error("Failed to get model " + mlTask.getModelId(), e); - handlePredictFailure(mlTask, internalListener, e, true, modelId); + handlePredictFailure(mlTask, internalListener, e, true, modelId, actionName); } } else { IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid"); log.error("ModelId is invalid", e); - handlePredictFailure(mlTask, internalListener, e, false, modelId); + handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName); } } @@ -445,19 +447,26 @@ private void handlePredictFailure( ActionListener listener, Exception e, boolean trackFailure, - String modelId + String modelId, + ActionName actionName ) { if (trackFailure) { - mlStats - .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) - .increment(); - mlStats.createModelCounterStatIfAbsent(modelId, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT); + mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), actionName, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); + mlStats.createModelCounterStatIfAbsent(modelId, actionName, MLActionLevelStat.ML_ACTION_FAILURE_COUNT); mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } handleAsyncMLTaskFailure(mlTask, e); listener.onFailure(e); } + private ActionName getActionNameFromInput(MLInput mlInput) { + ConnectorAction.ActionType actionType = null; + if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + actionType = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getActionType(); + } + return (actionType == null) ? ActionName.PREDICT : ActionName.from(actionType.toString()); + } + public void validateOutputSchema(String modelId, ModelTensorOutput output) { if (mlModelManager.getModelInterface(modelId) != null && mlModelManager.getModelInterface(modelId).get("output") != null) { String outputSchemaString = mlModelManager.getModelInterface(modelId).get("output"); diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index ce1305656e..71b4e7d08d 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -39,6 +39,7 @@ import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; @@ -308,4 +309,24 @@ public static void wrapListenerToHandleSearchIndexNotFound(Exception e, ActionLi } } + /** + * Determine the ActionType from the restful request by checking the url path and method name so there's no need + * to specify the ActionType in the request body. For example, /_plugins/_ml/models/{model_id}/_predict will return + * PREDICT as the ActionType, and /_plugins/_ml/models/{model_id}/_batch_predict will return BATCH_PREDICT. + * @param request A Restful request that needs to determine the ActionType from the path. + * @return parsed user object + */ + public static String getActionTypeFromRestRequest(RestRequest request) { + String path = request.path(); + String[] segments = path.split("/"); + String methodName = segments[segments.length - 1]; + methodName = methodName.startsWith("_") ? methodName.substring(1) : methodName; + + // find the action type for "/_plugins/_ml/_predict//" + if (!ActionType.isValidAction(methodName) && segments.length > 3) { + methodName = segments[3]; + methodName = methodName.contains("_") ? methodName.split("_")[1] : methodName; + } + return methodName; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java index d34e0fd00e..001b3709a8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java @@ -11,7 +11,10 @@ import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.TestHelper.getBatchRestRequest; +import static org.opensearch.ml.utils.TestHelper.getBatchRestRequest_WrongActionType; import static org.opensearch.ml.utils.TestHelper.getKMeansRestRequest; +import static org.opensearch.ml.utils.TestHelper.verifyParsedBatchMLInput; import static org.opensearch.ml.utils.TestHelper.verifyParsedKMeansMLInput; import java.io.IOException; @@ -107,6 +110,15 @@ public void testRoutes() { assertEquals("/_plugins/_ml/_predict/{algorithm}/{model_id}", route.getPath()); } + public void testRoutes_Batch() { + List routes = restMLPredictionAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(2); + assertEquals(RestRequest.Method.POST, route.getMethod()); + assertEquals("/_plugins/_ml/models/{model_id}/_batch_predict", route.getPath()); + } + public void testGetRequest() throws IOException { RestRequest request = getRestRequest_PredictModel(); MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request); @@ -143,6 +155,23 @@ public void testPrepareRequest() throws Exception { verifyParsedKMeansMLInput(mlInput); } + public void testPrepareBatchRequest() throws Exception { + RestRequest request = getBatchRestRequest(); + restMLPredictionAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argumentCaptor.capture(), any()); + MLInput mlInput = argumentCaptor.getValue().getMlInput(); + verifyParsedBatchMLInput(mlInput); + } + + public void testPrepareBatchRequest_WrongActionType() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Wrong Action Type"); + + RestRequest request = getBatchRestRequest_WrongActionType(); + restMLPredictionAction.getRequest("model id", "remote", request); + } + @Ignore public void testPrepareRequest_EmptyAlgorithm() throws Exception { MLModel model = MLModel.builder().algorithm(FunctionName.BATCH_RCF).build(); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 9ecf3be9e2..513b2497eb 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -16,6 +16,7 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import java.io.BufferedReader; import java.io.File; @@ -73,8 +74,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; @@ -211,6 +214,33 @@ public static RestRequest getKMeansRestRequest() { + "\"COSINE\"},\"input_query\":{\"_source\":[\"petal_length_in_cm\",\"petal_width_in_cm\"]," + "\"size\":10000},\"input_index\":[\"iris_data\"]}"; RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) + .withPath("/_plugins/_ml/models/{model_id}}/_predict") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + public static RestRequest getBatchRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_MODEL_ID, "sample model"); + params.put(PARAMETER_ALGORITHM, "remote"); + final String requestContent = "{\"parameters\":{\"TransformJobName\":\"SM-offline-batch-transform-07-17-14-30\"}}"; + RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) + .withPath("/_plugins/_ml/models/{model_id}/_batch_predict") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + public static RestRequest getBatchRestRequest_WrongActionType() { + Map params = new HashMap<>(); + params.put(PARAMETER_MODEL_ID, "sample model"); + params.put(PARAMETER_ALGORITHM, "remote"); + final String requestContent = "{\"parameters\":{\"TransformJobName\":\"SM-offline-batch-transform-07-17-14-30\"}}"; + RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) + .withPath("/_plugins/_ml/models/{model_id}}/_BadType") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -356,6 +386,13 @@ public static void verifyParsedKMeansMLInput(MLInput mlInput) { assertEquals(3, kMeansParams.getCentroids().intValue()); } + public static void verifyParsedBatchMLInput(MLInput mlInput) { + assertEquals(FunctionName.REMOTE, mlInput.getAlgorithm()); + assertEquals(MLInputDataType.REMOTE, mlInput.getInputDataset().getInputDataType()); + RemoteInferenceInputDataSet inputDataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, inputDataset.getActionType()); + } + private static NamedXContentRegistry getXContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); List entries = new ArrayList<>();