Skip to content

Commit

Permalink
Add Batch Prediction Mode in the Connector Framework for batch infere…
Browse files Browse the repository at this point in the history
…nce (opensearch-project#2661)

* add batch predict job actiontype in connector

Signed-off-by: Xun Zhang <[email protected]>

* remove async and streaming mode temporarily

Signed-off-by: Xun Zhang <[email protected]>

* rename predict mode to action type

Signed-off-by: Xun Zhang <[email protected]>

* use method name in the url path for action type

Signed-off-by: Xun Zhang <[email protected]>

* add stats for actionType and more UTs

Signed-off-by: Xun Zhang <[email protected]>

* add bwx for actiontype

Signed-off-by: Xun Zhang <[email protected]>

* address more comments

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored and mingshl committed Jul 24, 2024
1 parent 3655235 commit 14c2363
Show file tree
Hide file tree
Showing 17 changed files with 317 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -538,4 +538,5 @@ public class CommonValue {
public static final Version VERSION_2_12_0 = Version.fromString("2.12.0");
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;

import java.io.IOException;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand Down Expand Up @@ -183,6 +186,33 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {

public enum ActionType {
PREDICT,
EXECUTE
EXECUTE,
BATCH_PREDICT;

public static ActionType from(String value) {
try {
return ActionType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong Action Type of " + value);
}
}

private static final HashSet<ActionType> MODEL_SUPPORT_ACTIONS = new HashSet<>(Set.of(
PREDICT,
BATCH_PREDICT
));

public static boolean isValidActionInModelPrediction(ActionType actionType) {
return MODEL_SUPPORT_ACTIONS.contains(actionType);
}

public static boolean isValidAction(String action) {
try {
ActionType.valueOf(action.toUpperCase());
return true;
} catch (IllegalArgumentException e) {
return false;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
Expand All @@ -20,32 +23,56 @@
@Getter
@InputDataSet(MLInputDataType.REMOTE)
public class RemoteInferenceInputDataSet extends MLInputDataset {

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0;
@Setter
private Map<String, String> parameters;
@Setter
private ActionType actionType;

@Builder(toBuilder = true)
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
super(MLInputDataType.REMOTE);
this.parameters = parameters;
this.actionType = actionType;
}

public RemoteInferenceInputDataSet(Map<String, String> parameters) {
this(parameters, null);
}

public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.REMOTE);
Version streamInputVersion = streamInput.getVersion();
if (streamInput.readBoolean()) {
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
}
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
if (streamInput.readBoolean()) {
actionType = streamInput.readEnum(ActionType.class);
} else {
this.actionType = null;
}
}
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
Version streamOutputVersion = streamOutput.getVersion();
if (parameters != null) {
streamOutput.writeBoolean(true);
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
} else {
streamOutput.writeBoolean(false);
}
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
if (actionType != null) {
streamOutput.writeBoolean(true);
streamOutput.writeEnum(actionType);
} else {
streamOutput.writeBoolean(false);
}
}
}

}
14 changes: 14 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/input/MLInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
Expand All @@ -35,6 +36,7 @@
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.input.remote.RemoteInferenceMLInput.ACTION_TYPE_FIELD;

/**
* ML input data: algorithm name, parameters and input data set.
Expand Down Expand Up @@ -196,6 +198,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) this.inputDataset;
Map<String, String> parameters = remoteInferenceInputDataSet.getParameters();
builder.field(PARAMETERS_FIELD, parameters);
builder.field(ACTION_TYPE_FIELD, remoteInferenceInputDataSet.getActionType());
break;
default:
break;
Expand All @@ -206,6 +209,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

public static MLInput parse(XContentParser parser, String inputAlgoName, ActionType actionType) throws IOException {
MLInput mlInput = parse(parser, inputAlgoName);
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet)mlInput.getInputDataset();
if (remoteInferenceInputDataSet.getActionType() == null) {
remoteInferenceInputDataSet.setActionType(actionType);
}
}
return mlInput;
}

public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT);
FunctionName algorithm = FunctionName.from(algorithmName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;
Expand All @@ -21,6 +22,7 @@
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE})
public class RemoteInferenceMLInput extends MLInput {
public static final String PARAMETERS_FIELD = "parameters";
public static final String ACTION_TYPE_FIELD = "action_type";

public RemoteInferenceMLInput(StreamInput in) throws IOException {
super(in);
Expand All @@ -34,21 +36,26 @@ public void writeTo(StreamOutput out) throws IOException {
public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) throws IOException {
super();
this.algorithm = functionName;
Map<String, String> parameters = null;
ActionType actionType = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case PARAMETERS_FIELD:
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
inputDataset = new RemoteInferenceInputDataSet(parameters);
parameters = StringUtils.getParameterMap(parser.map());
break;
case ACTION_TYPE_FIELD:
actionType = ActionType.from(parser.text());
break;
default:
parser.skipChildren();
break;
}
}
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction;

public class ConnectorActionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -140,4 +142,17 @@ public void parse() throws IOException {
Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
}

@Test
public void test_wrongActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Action Type");
ConnectorAction.ActionType.from("badAction");
}

@Test
public void test_invalidActionInModelPrediction() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute");
Assert.assertEquals(isValidActionInModelPrediction(actionType), false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.MLInputDataset;

import java.io.IOException;
Expand Down Expand Up @@ -45,4 +46,24 @@ public void writeTo() throws IOException {
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
}

@Test
public void writeTo_withActionType() throws IOException {
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "test value1");
parameters.put("key2", "test value2");
ActionType actionType = ActionType.from("predict");
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).actionType(actionType).build();

BytesStreamOutput output = new BytesStreamOutput();
inputDataSet.writeTo(output);
StreamInput streamInput = output.bytes().streamInput();

RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput);
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType());
Assert.assertEquals(2, inputDataSet2.getParameters().size());
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
Assert.assertEquals("PREDICT", inputDataSet2.getActionType().toString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.dataframe.*;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.nlp.TextSimilarityMLInput;
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
import org.opensearch.ml.common.output.model.ModelResultFilter;
Expand All @@ -37,7 +39,9 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;

Expand Down Expand Up @@ -160,6 +164,40 @@ public void parse_NLPRelated_NullResultFilter() throws IOException {
parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
}

@Test
public void parse_Remote_Model() throws IOException {
Map<String, String> parameters = Map.of("TransformJobName", "new name");
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
.parameters(parameters)
.actionType(ConnectorAction.ActionType.PREDICT)
.build();

String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"PREDICT\"}";

testParse(FunctionName.REMOTE, remoteInferenceInputDataSet, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
assertEquals(ConnectorAction.ActionType.PREDICT, parsedInputDataSet.getActionType());
});
}

@Test
public void parse_Remote_Model_With_ActionType() throws IOException {
Map<String, String> parameters = Map.of("TransformJobName", "new name");
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
.parameters(parameters)
.actionType(ConnectorAction.ActionType.BATCH_PREDICT)
.build();

String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH_PREDICT\"}";

testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH_PREDICT, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, parsedInputDataSet.getActionType());
});
}

private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
Expand All @@ -178,6 +216,24 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri
verify.accept(parsedInput);
}

private void testParseWithActionType(FunctionName algorithm, MLInputDataset inputDataset, ConnectorAction.ActionType actionType, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
assertEquals(expectedInputStr, jsonStr);

XContentParser parser = XContentType.JSON.xContent()
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
MLInput parsedInput = MLInput.parse(parser, algorithm.name(), actionType);
assertEquals(input.getFunctionName(), parsedInput.getFunctionName());
assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType());
verify.accept(parsedInput);
}

@Test
public void readInputStream_Success() throws IOException {
readInputStream(input, parsedInput -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ public void constructor_stream() throws IOException {
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
Assert.assertEquals(1, inputDataSet.getParameters().size());
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
Assert.assertEquals("BATCH_PREDICT", inputDataSet.getActionType().toString());
}

private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException {
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }";
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"batch_predict\" }";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
Expand Down Expand Up @@ -70,7 +72,12 @@ public void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionL
return;
}
try {
connectorExecutor.executeAction(PREDICT.name(), mlInput, actionListener);
ActionType actionType = null;
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
actionType = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getActionType();
}
actionType = actionType == null ? ActionType.PREDICT : actionType;
connectorExecutor.executeAction(actionType.toString(), mlInput, actionListener);
} catch (RuntimeException e) {
log.error("Failed to call remote model.", e);
actionListener.onFailure(e);
Expand Down
Loading

0 comments on commit 14c2363

Please sign in to comment.