Skip to content

Commit

Permalink
add bwx for actiontype
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Jul 22, 2024
1 parent 643ab80 commit a88775f
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -538,4 +538,5 @@ public class CommonValue {
public static final Version VERSION_2_12_0 = Version.fromString("2.12.0");
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
}
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ public static ActionType from(String value) {
try {
return ActionType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong Action Type");
throw new IllegalArgumentException("Wrong Action Type of " + value);
}
}

Expand All @@ -205,5 +205,14 @@ public static ActionType from(String value) {
public static boolean isValidActionInModelPrediction(ActionType actionType) {
return MODEL_SUPPORT_ACTIONS.contains(actionType);
}

public static boolean isValidAction(String action) {
try {
ActionType.valueOf(action.toUpperCase());
return true;
} catch (IllegalArgumentException e) {
return false;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.dataset.MLInputDataType;
Expand All @@ -21,7 +23,7 @@
@Getter
@InputDataSet(MLInputDataType.REMOTE)
public class RemoteInferenceInputDataSet extends MLInputDataset {

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0;
@Setter
private Map<String, String> parameters;
@Setter
Expand All @@ -40,30 +42,36 @@ public RemoteInferenceInputDataSet(Map<String, String> parameters) {

public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.REMOTE);
Version streamInputVersion = streamInput.getVersion();
if (streamInput.readBoolean()) {
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
}
if (streamInput.readBoolean()) {
actionType = streamInput.readEnum(ActionType.class);
} else {
this.actionType = null;
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
if (streamInput.readBoolean()) {
actionType = streamInput.readEnum(ActionType.class);
} else {
this.actionType = null;
}
}
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
Version streamOutputVersion = streamOutput.getVersion();
if (parameters != null) {
streamOutput.writeBoolean(true);
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
} else {
streamOutput.writeBoolean(false);
}
if (actionType != null) {
streamOutput.writeBoolean(true);
streamOutput.writeEnum(actionType);
} else {
streamOutput.writeBoolean(false);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
if (actionType != null) {
streamOutput.writeBoolean(true);
streamOutput.writeEnum(actionType);
} else {
streamOutput.writeBoolean(false);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import java.util.Arrays;
import java.util.UUID;

import javax.swing.*;

import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -313,8 +314,12 @@ public static String getActionTypeFromRestRequest(RestRequest request) {
String path = request.path();
String[] segments = path.split("/");
String methodName = segments[segments.length - 1];
if (methodName.contains("_")) {
methodName = methodName.split("_")[1];
methodName = methodName.contains("_") ? methodName.split("_")[1] : methodName;

// find the action type for "/_plugins/_ml/_predict/<algorithm>/<model_id>"
if (!ActionType.isValidAction(methodName) && segments.length > 3) {
methodName = segments[3];
methodName = methodName.contains("_") ? methodName.split("_")[1] : methodName;
}
return methodName;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ public static RestRequest getKMeansRestRequest() {
+ "\"COSINE\"},\"input_query\":{\"_source\":[\"petal_length_in_cm\",\"petal_width_in_cm\"],"
+ "\"size\":10000},\"input_index\":[\"iris_data\"]}";
RestRequest request = new FakeRestRequest.Builder(getXContentRegistry())
.withPath("/_plugins/_ml/models/{model_id}}/_predict")
.withParams(params)
.withContent(new BytesArray(requestContent), XContentType.JSON)
.build();
Expand Down

0 comments on commit a88775f

Please sign in to comment.