Skip to content

Commit

Permalink
Add Batch Prediction Mode in the Connector Framework for batch infere…
Browse files Browse the repository at this point in the history
…nce (#2661) (#2701)

* add batch predict job actiontype in connector



* remove async and streaming mode temporarily



* rename predict mode to action type



* use method name in the url path for action type



* add stats for actionType and more UTs



* add bwx for actiontype



* address more comments



---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored Jul 22, 2024
1 parent 9b2e5f1 commit 310d023
Show file tree
Hide file tree
Showing 19 changed files with 314 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -538,4 +538,5 @@ public class CommonValue {
public static final Version VERSION_2_12_0 = Version.fromString("2.12.0");
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;

import java.io.IOException;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand Down Expand Up @@ -183,6 +186,33 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {

public enum ActionType {
PREDICT,
EXECUTE
EXECUTE,
BATCH_PREDICT;

public static ActionType from(String value) {
try {
return ActionType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong Action Type of " + value);
}
}

private static final HashSet<ActionType> MODEL_SUPPORT_ACTIONS = new HashSet<>(Set.of(
PREDICT,
BATCH_PREDICT
));

public static boolean isValidActionInModelPrediction(ActionType actionType) {
return MODEL_SUPPORT_ACTIONS.contains(actionType);
}

public static boolean isValidAction(String action) {
try {
ActionType.valueOf(action.toUpperCase());
return true;
} catch (IllegalArgumentException e) {
return false;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,68 @@
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;

@Getter
@InputDataSet(MLInputDataType.REMOTE)
public class RemoteInferenceInputDataSet extends MLInputDataset {

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0;
@Setter
private Map<String, String> parameters;
@Setter
private ActionType actionType;

@Builder(toBuilder = true)
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
super(MLInputDataType.REMOTE);
this.parameters = parameters;
this.actionType = actionType;
}

public RemoteInferenceInputDataSet(Map<String, String> parameters) {
this(parameters, null);
}

public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.REMOTE);
Version streamInputVersion = streamInput.getVersion();
if (streamInput.readBoolean()) {
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
}
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
if (streamInput.readBoolean()) {
actionType = streamInput.readEnum(ActionType.class);
} else {
this.actionType = null;
}
}
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
Version streamOutputVersion = streamOutput.getVersion();
if (parameters != null) {
streamOutput.writeBoolean(true);
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
} else {
streamOutput.writeBoolean(false);
}
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
if (actionType != null) {
streamOutput.writeBoolean(true);
streamOutput.writeEnum(actionType);
} else {
streamOutput.writeBoolean(false);
}
}
}

}
14 changes: 14 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/input/MLInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
Expand All @@ -35,6 +36,7 @@
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.input.remote.RemoteInferenceMLInput.ACTION_TYPE_FIELD;

/**
* ML input data: algorithm name, parameters and input data set.
Expand Down Expand Up @@ -196,6 +198,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) this.inputDataset;
Map<String, String> parameters = remoteInferenceInputDataSet.getParameters();
builder.field(PARAMETERS_FIELD, parameters);
builder.field(ACTION_TYPE_FIELD, remoteInferenceInputDataSet.getActionType());
break;
default:
break;
Expand All @@ -206,6 +209,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

public static MLInput parse(XContentParser parser, String inputAlgoName, ActionType actionType) throws IOException {
MLInput mlInput = parse(parser, inputAlgoName);
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet)mlInput.getInputDataset();
if (remoteInferenceInputDataSet.getActionType() == null) {
remoteInferenceInputDataSet.setActionType(actionType);
}
}
return mlInput;
}

public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT);
FunctionName algorithm = FunctionName.from(algorithmName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;
Expand All @@ -21,6 +22,7 @@
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE})
public class RemoteInferenceMLInput extends MLInput {
public static final String PARAMETERS_FIELD = "parameters";
public static final String ACTION_TYPE_FIELD = "action_type";

public RemoteInferenceMLInput(StreamInput in) throws IOException {
super(in);
Expand All @@ -34,21 +36,26 @@ public void writeTo(StreamOutput out) throws IOException {
public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) throws IOException {
super();
this.algorithm = functionName;
Map<String, String> parameters = null;
ActionType actionType = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case PARAMETERS_FIELD:
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
inputDataset = new RemoteInferenceInputDataSet(parameters);
parameters = StringUtils.getParameterMap(parser.map());
break;
case ACTION_TYPE_FIELD:
actionType = ActionType.from(parser.text());
break;
default:
parser.skipChildren();
break;
}
}
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction;

public class ConnectorActionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -140,4 +142,17 @@ public void parse() throws IOException {
Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
}

@Test
public void test_wrongActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Action Type");
ConnectorAction.ActionType.from("badAction");
}

@Test
public void test_invalidActionInModelPrediction() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute");
Assert.assertEquals(isValidActionInModelPrediction(actionType), false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.MLInputDataset;

public class RemoteInferenceInputDataSetTest {
Expand Down Expand Up @@ -44,4 +45,24 @@ public void writeTo() throws IOException {
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
}

@Test
public void writeTo_withActionType() throws IOException {
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "test value1");
parameters.put("key2", "test value2");
ActionType actionType = ActionType.from("predict");
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).actionType(actionType).build();

BytesStreamOutput output = new BytesStreamOutput();
inputDataSet.writeTo(output);
StreamInput streamInput = output.bytes().streamInput();

RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput);
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType());
Assert.assertEquals(2, inputDataSet2.getParameters().size());
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
Assert.assertEquals("PREDICT", inputDataSet2.getActionType().toString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
import org.opensearch.ml.common.dataframe.DoubleValue;
import org.opensearch.ml.common.dataframe.Row;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.nlp.TextSimilarityMLInput;
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
import org.opensearch.ml.common.output.model.ModelResultFilter;
Expand All @@ -44,6 +46,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;

Expand Down Expand Up @@ -168,6 +171,40 @@ public void parse_NLPRelated_NullResultFilter() throws IOException {
parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
}

@Test
public void parse_Remote_Model() throws IOException {
Map<String, String> parameters = Map.of("TransformJobName", "new name");
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
.parameters(parameters)
.actionType(ConnectorAction.ActionType.PREDICT)
.build();

String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"PREDICT\"}";

testParse(FunctionName.REMOTE, remoteInferenceInputDataSet, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
assertEquals(ConnectorAction.ActionType.PREDICT, parsedInputDataSet.getActionType());
});
}

@Test
public void parse_Remote_Model_With_ActionType() throws IOException {
Map<String, String> parameters = Map.of("TransformJobName", "new name");
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
.parameters(parameters)
.actionType(ConnectorAction.ActionType.BATCH_PREDICT)
.build();

String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH_PREDICT\"}";

testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH_PREDICT, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, parsedInputDataSet.getActionType());
});
}

private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
Expand All @@ -186,6 +223,24 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri
verify.accept(parsedInput);
}

private void testParseWithActionType(FunctionName algorithm, MLInputDataset inputDataset, ConnectorAction.ActionType actionType, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
assertEquals(expectedInputStr, jsonStr);

XContentParser parser = XContentType.JSON.xContent()
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
MLInput parsedInput = MLInput.parse(parser, algorithm.name(), actionType);
assertEquals(input.getFunctionName(), parsedInput.getFunctionName());
assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType());
verify.accept(parsedInput);
}

@Test
public void readInputStream_Success() throws IOException {
readInputStream(input, parsedInput -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ public void constructor_stream() throws IOException {
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
Assert.assertEquals(1, inputDataSet.getParameters().size());
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
Assert.assertEquals("BATCH_PREDICT", inputDataSet.getActionType().toString());
}

private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException {
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }";
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"batch_predict\" }";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,6 @@ public void testGetSg_NoIndex_ThenFail() {
interactionsIndex.getInteraction("iid", getListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(getListener, times(1)).onFailure(argCaptor.capture());
System.out.println(argCaptor.getValue().getMessage());
assert (argCaptor
.getValue()
.getMessage()
Expand Down
Loading

0 comments on commit 310d023

Please sign in to comment.