Skip to content

Commit

Permalink
address more comments
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Jul 22, 2024
1 parent a88775f commit 489ee29
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
public enum ActionType {
PREDICT,
EXECUTE,
BATCH;
BATCH_PREDICT;

public static ActionType from(String value) {
try {
Expand All @@ -199,7 +199,7 @@ public static ActionType from(String value) {

private static final HashSet<ActionType> MODEL_SUPPORT_ACTIONS = new HashSet<>(Set.of(
PREDICT,
BATCH
BATCH_PREDICT
));

public static boolean isValidActionInModelPrediction(ActionType actionType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,15 @@ public void parse_Remote_Model_With_ActionType() throws IOException {
Map<String, String> parameters = Map.of("TransformJobName", "new name");
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
.parameters(parameters)
.actionType(ConnectorAction.ActionType.BATCH)
.actionType(ConnectorAction.ActionType.BATCH_PREDICT)
.build();

String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH\"}";
String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH_PREDICT\"}";

testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH, expectedInputStr, parsedInput -> {
testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH_PREDICT, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
assertEquals(ConnectorAction.ActionType.BATCH, parsedInputDataSet.getActionType());
assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, parsedInputDataSet.getActionType());
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ public void constructor_stream() throws IOException {
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
Assert.assertEquals(1, inputDataSet.getParameters().size());
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
Assert.assertEquals("BATCH", inputDataSet.getActionType().toString());
Assert.assertEquals("BATCH_PREDICT", inputDataSet.getActionType().toString());
}

private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException {
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"BATCH\" }";
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"batch_predict\" }";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ public List<Route> routes() {
String.format(Locale.ROOT, "%s/_predict/{%s}/{%s}", ML_BASE_URI, PARAMETER_ALGORITHM, PARAMETER_MODEL_ID)
),
new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/{%s}/_predict", ML_BASE_URI, PARAMETER_MODEL_ID)),
new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/{%s}/_batch", ML_BASE_URI, PARAMETER_MODEL_ID))
new Route(
RestRequest.Method.POST,
String.format(Locale.ROOT, "%s/models/{%s}/_batch_predict", ML_BASE_URI, PARAMETER_MODEL_ID)
)
);
}

Expand Down Expand Up @@ -124,11 +127,13 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
@VisibleForTesting
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
System.out.println("actionType is " + actionType);
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
} else if (!ActionType.isValidActionInModelPrediction(actionType)) {
System.out.println(actionType.toString());
throw new IllegalArgumentException("Wrong action type in the rest request path!");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public enum ActionName {
REGISTER,
DEPLOY,
UNDEPLOY,
BATCH;
BATCH_PREDICT;

public static ActionName from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,7 @@ private ActionName getActionNameFromInput(MLInput mlInput) {
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
actionType = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getActionType();
}
if (actionType == null) {
return ActionName.PREDICT;
} else {
return ActionName.from(actionType.toString());
}
return (actionType == null) ? ActionName.PREDICT : ActionName.from(actionType.toString());
}

public void validateOutputSchema(String modelId, ModelTensorOutput output) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,19 @@ public static void wrapListenerToHandleSearchIndexNotFound(Exception e, ActionLi
}
}

/**
* Determine the ActionType from the restful request by checking the url path and method name so there's no need
* to specify the ActionType in the request body. For example, /_plugins/_ml/models/{model_id}/_predict will return
* PREDICT as the ActionType, and /_plugins/_ml/models/{model_id}/_batch_predict will return BATCH_PREDICT.
* @param request A Restful request that needs to determine the ActionType from the path.
* @return parsed user object
*/
public static String getActionTypeFromRestRequest(RestRequest request) {
String path = request.path();
System.out.println("path is " + path);
String[] segments = path.split("/");
String methodName = segments[segments.length - 1];
methodName = methodName.contains("_") ? methodName.split("_")[1] : methodName;
methodName = methodName.startsWith("_") ? methodName.substring(1) : methodName;

// find the action type for "/_plugins/_ml/_predict/<algorithm>/<model_id>"
if (!ActionType.isValidAction(methodName) && segments.length > 3) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public void testRoutes_Batch() {
assertFalse(routes.isEmpty());
RestHandler.Route route = routes.get(2);
assertEquals(RestRequest.Method.POST, route.getMethod());
assertEquals("/_plugins/_ml/models/{model_id}/_batch", route.getPath());
assertEquals("/_plugins/_ml/models/{model_id}/_batch_predict", route.getPath());
}

public void testGetRequest() throws IOException {
Expand Down
4 changes: 2 additions & 2 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ public static RestRequest getBatchRestRequest() {
params.put(PARAMETER_ALGORITHM, "remote");
final String requestContent = "{\"parameters\":{\"TransformJobName\":\"SM-offline-batch-transform-07-17-14-30\"}}";
RestRequest request = new FakeRestRequest.Builder(getXContentRegistry())
.withPath("/_plugins/_ml/models/{model_id}}/_batch")
.withPath("/_plugins/_ml/models/{model_id}/_batch_predict")
.withParams(params)
.withContent(new BytesArray(requestContent), XContentType.JSON)
.build();
Expand Down Expand Up @@ -388,7 +388,7 @@ public static void verifyParsedBatchMLInput(MLInput mlInput) {
assertEquals(FunctionName.REMOTE, mlInput.getAlgorithm());
assertEquals(MLInputDataType.REMOTE, mlInput.getInputDataset().getInputDataType());
RemoteInferenceInputDataSet inputDataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
assertEquals(ConnectorAction.ActionType.BATCH, inputDataset.getActionType());
assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, inputDataset.getActionType());
}

private static NamedXContentRegistry getXContentRegistry() {
Expand Down

0 comments on commit 489ee29

Please sign in to comment.