diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 192b7a9737..93564a7d4b 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -538,4 +538,5 @@ public class CommonValue { public static final Version VERSION_2_12_0 = Version.fromString("2.12.0"); public static final Version VERSION_2_13_0 = Version.fromString("2.13.0"); public static final Version VERSION_2_14_0 = Version.fromString("2.14.0"); + public static final Version VERSION_2_16_0 = Version.fromString("2.16.0"); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 1562894cee..b0373ca4e9 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -193,7 +193,7 @@ public static ActionType from(String value) { try { return ActionType.valueOf(value.toUpperCase(Locale.ROOT)); } catch (Exception e) { - throw new IllegalArgumentException("Wrong Action Type"); + throw new IllegalArgumentException("Wrong Action Type of " + value); } } @@ -205,5 +205,14 @@ public static ActionType from(String value) { public static boolean isValidActionInModelPrediction(ActionType actionType) { return MODEL_SUPPORT_ACTIONS.contains(actionType); } + + public static boolean isValidAction(String action) { + try { + ActionType.valueOf(action.toUpperCase()); + return true; + } catch (IllegalArgumentException e) { + return false; + } + } } } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java index ad145213bc..09a6e4f269 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java @@ -8,8 +8,10 @@ import lombok.Builder; import lombok.Getter; import lombok.Setter; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.dataset.MLInputDataType; @@ -21,7 +23,7 @@ @Getter @InputDataSet(MLInputDataType.REMOTE) public class RemoteInferenceInputDataSet extends MLInputDataset { - + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0; @Setter private Map parameters; @Setter @@ -40,30 +42,36 @@ public RemoteInferenceInputDataSet(Map parameters) { public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException { super(MLInputDataType.REMOTE); + Version streamInputVersion = streamInput.getVersion(); if (streamInput.readBoolean()) { parameters = streamInput.readMap(s -> s.readString(), s-> s.readString()); } - if (streamInput.readBoolean()) { - actionType = streamInput.readEnum(ActionType.class); - } else { - this.actionType = null; + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) { + if (streamInput.readBoolean()) { + actionType = streamInput.readEnum(ActionType.class); + } else { + this.actionType = null; + } } } @Override public void writeTo(StreamOutput streamOutput) throws IOException { super.writeTo(streamOutput); + Version streamOutputVersion = streamOutput.getVersion(); if (parameters != null) { streamOutput.writeBoolean(true); streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); } else { streamOutput.writeBoolean(false); } - if (actionType != null) { - streamOutput.writeBoolean(true); - streamOutput.writeEnum(actionType); - } else { - streamOutput.writeBoolean(false); + if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) { + if (actionType != null) { + streamOutput.writeBoolean(true); + streamOutput.writeEnum(actionType); + } else { + streamOutput.writeBoolean(false); + } } } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index e2190f8a23..2028d4770c 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -19,8 +19,6 @@ import java.util.Arrays; import java.util.UUID; -import javax.swing.*; - import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index 2a3e091281..1f7839ee86 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -39,6 +39,7 @@ import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; @@ -313,8 +314,12 @@ public static String getActionTypeFromRestRequest(RestRequest request) { String path = request.path(); String[] segments = path.split("/"); String methodName = segments[segments.length - 1]; - if (methodName.contains("_")) { - methodName = methodName.split("_")[1]; + methodName = methodName.contains("_") ? methodName.split("_")[1] : methodName; + + // find the action type for "/_plugins/_ml/_predict//" + if (!ActionType.isValidAction(methodName) && segments.length > 3) { + methodName = segments[3]; + methodName = methodName.contains("_") ? methodName.split("_")[1] : methodName; } return methodName; } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 88488df351..01f4563c79 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -212,6 +212,7 @@ public static RestRequest getKMeansRestRequest() { + "\"COSINE\"},\"input_query\":{\"_source\":[\"petal_length_in_cm\",\"petal_width_in_cm\"]," + "\"size\":10000},\"input_index\":[\"iris_data\"]}"; RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) + .withPath("/_plugins/_ml/models/{model_id}}/_predict") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build();