diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index b0373ca4e9..9be290d126 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -187,7 +187,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { public enum ActionType { PREDICT, EXECUTE, - BATCH; + BATCH_PREDICT; public static ActionType from(String value) { try { @@ -199,7 +199,7 @@ public static ActionType from(String value) { private static final HashSet MODEL_SUPPORT_ACTIONS = new HashSet<>(Set.of( PREDICT, - BATCH + BATCH_PREDICT )); public static boolean isValidActionInModelPrediction(ActionType actionType) { diff --git a/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java index 9b7992faf6..d12cb0e2da 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java @@ -186,15 +186,15 @@ public void parse_Remote_Model_With_ActionType() throws IOException { Map parameters = Map.of("TransformJobName", "new name"); RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder() .parameters(parameters) - .actionType(ConnectorAction.ActionType.BATCH) + .actionType(ConnectorAction.ActionType.BATCH_PREDICT) .build(); - String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH\"}"; + String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH_PREDICT\"}"; - testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH, expectedInputStr, parsedInput -> { + testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH_PREDICT, expectedInputStr, parsedInput -> { assertNotNull(parsedInput.getInputDataset()); RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset(); - assertEquals(ConnectorAction.ActionType.BATCH, parsedInputDataSet.getActionType()); + assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, parsedInputDataSet.getActionType()); }); } diff --git a/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java index 153f549fa2..759bf154b3 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java @@ -39,11 +39,11 @@ public void constructor_stream() throws IOException { RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset(); Assert.assertEquals(1, inputDataSet.getParameters().size()); Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt")); - Assert.assertEquals("BATCH", inputDataSet.getActionType().toString()); + Assert.assertEquals("BATCH_PREDICT", inputDataSet.getActionType().toString()); } private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException { - String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"BATCH\" }"; + String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"batch_predict\" }"; XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), null, jsonStr); parser.nextToken(); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index 22ad209cfd..82c72e11a2 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -72,7 +72,10 @@ public List routes() { String.format(Locale.ROOT, "%s/_predict/{%s}/{%s}", ML_BASE_URI, PARAMETER_ALGORITHM, PARAMETER_MODEL_ID) ), new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/{%s}/_predict", ML_BASE_URI, PARAMETER_MODEL_ID)), - new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/{%s}/_batch", ML_BASE_URI, PARAMETER_MODEL_ID)) + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/models/{%s}/_batch_predict", ML_BASE_URI, PARAMETER_MODEL_ID) + ) ); } @@ -124,11 +127,13 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client @VisibleForTesting MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException { ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request)); + System.out.println("actionType is " + actionType); if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); } else if (!ActionType.isValidActionInModelPrediction(actionType)) { + System.out.println(actionType.toString()); throw new IllegalArgumentException("Wrong action type in the rest request path!"); } diff --git a/plugin/src/main/java/org/opensearch/ml/stats/ActionName.java b/plugin/src/main/java/org/opensearch/ml/stats/ActionName.java index 7431c08ae7..8525793a31 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/ActionName.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/ActionName.java @@ -13,7 +13,7 @@ public enum ActionName { REGISTER, DEPLOY, UNDEPLOY, - BATCH; + BATCH_PREDICT; public static ActionName from(String value) { try { diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 2028d4770c..72c43bd58f 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -464,11 +464,7 @@ private ActionName getActionNameFromInput(MLInput mlInput) { if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { actionType = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getActionType(); } - if (actionType == null) { - return ActionName.PREDICT; - } else { - return ActionName.from(actionType.toString()); - } + return (actionType == null) ? ActionName.PREDICT : ActionName.from(actionType.toString()); } public void validateOutputSchema(String modelId, ModelTensorOutput output) { diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index 1f7839ee86..5f5f567eb8 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -310,11 +310,19 @@ public static void wrapListenerToHandleSearchIndexNotFound(Exception e, ActionLi } } + /** + * Determine the ActionType from the restful request by checking the url path and method name so there's no need + * to specify the ActionType in the request body. For example, /_plugins/_ml/models/{model_id}/_predict will return + * PREDICT as the ActionType, and /_plugins/_ml/models/{model_id}/_batch_predict will return BATCH_PREDICT. + * @param request A Restful request that needs to determine the ActionType from the path. + * @return parsed user object + */ public static String getActionTypeFromRestRequest(RestRequest request) { String path = request.path(); + System.out.println("path is " + path); String[] segments = path.split("/"); String methodName = segments[segments.length - 1]; - methodName = methodName.contains("_") ? methodName.split("_")[1] : methodName; + methodName = methodName.startsWith("_") ? methodName.substring(1) : methodName; // find the action type for "/_plugins/_ml/_predict//" if (!ActionType.isValidAction(methodName) && segments.length > 3) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java index ab22504be7..001b3709a8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java @@ -116,7 +116,7 @@ public void testRoutes_Batch() { assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(2); assertEquals(RestRequest.Method.POST, route.getMethod()); - assertEquals("/_plugins/_ml/models/{model_id}/_batch", route.getPath()); + assertEquals("/_plugins/_ml/models/{model_id}/_batch_predict", route.getPath()); } public void testGetRequest() throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 01f4563c79..ca5046fa0b 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -225,7 +225,7 @@ public static RestRequest getBatchRestRequest() { params.put(PARAMETER_ALGORITHM, "remote"); final String requestContent = "{\"parameters\":{\"TransformJobName\":\"SM-offline-batch-transform-07-17-14-30\"}}"; RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) - .withPath("/_plugins/_ml/models/{model_id}}/_batch") + .withPath("/_plugins/_ml/models/{model_id}/_batch_predict") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -388,7 +388,7 @@ public static void verifyParsedBatchMLInput(MLInput mlInput) { assertEquals(FunctionName.REMOTE, mlInput.getAlgorithm()); assertEquals(MLInputDataType.REMOTE, mlInput.getInputDataset().getInputDataType()); RemoteInferenceInputDataSet inputDataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); - assertEquals(ConnectorAction.ActionType.BATCH, inputDataset.getActionType()); + assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, inputDataset.getActionType()); } private static NamedXContentRegistry getXContentRegistry() {