Skip to content

Commit

Permalink
add task to batch predict request
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Jul 22, 2024
1 parent 8d4aefa commit 02f1107
Show file tree
Hide file tree
Showing 8 changed files with 428 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public class CommonValue {
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 3;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2;
Expand Down Expand Up @@ -359,6 +359,10 @@ public class CommonValue {
+ "\" : {\"type\" : \"boolean\"}, \n"
+ USER_FIELD_MAPPING
+ " }\n"
+ "}"
+ MLTask.TRANSFORM_JOB_FIELD
+ "\" : {\"type\": \"flat_object\"}\n"
+ " }\n"
+ "}";

public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n"
Expand Down Expand Up @@ -537,4 +541,5 @@ public class CommonValue {
public static final Version VERSION_2_12_0 = Version.fromString("2.12.0");
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
}
47 changes: 46 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand All @@ -17,15 +18,22 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.USER;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.gson;

@Getter
@EqualsAndHashCode
Expand All @@ -44,6 +52,8 @@ public class MLTask implements ToXContentObject, Writeable {
public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
public static final String ERROR_FIELD = "error";
public static final String IS_ASYNC_TASK_FIELD = "is_async";
public static final String TRANSFORM_JOB_FIELD = "transform_job";
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB = CommonValue.VERSION_2_16_0;

@Setter
private String taskId;
Expand All @@ -65,6 +75,8 @@ public class MLTask implements ToXContentObject, Writeable {
private String error;
private User user; // TODO: support document level access control later
private boolean async;
@Setter
private Map<String, Object> transformJob;

@Builder(toBuilder = true)
public MLTask(
Expand All @@ -81,7 +93,8 @@ public MLTask(
Instant lastUpdateTime,
String error,
User user,
boolean async
boolean async,
Map<String, Object> transformJob
) {
this.taskId = taskId;
this.modelId = modelId;
Expand All @@ -97,9 +110,11 @@ public MLTask(
this.error = error;
this.user = user;
this.async = async;
this.transformJob = transformJob;
}

public MLTask(StreamInput input) throws IOException {
Version streamInputVersion = input.getVersion();
this.taskId = input.readOptionalString();
this.modelId = input.readOptionalString();
this.taskType = input.readEnum(MLTaskType.class);
Expand All @@ -122,10 +137,17 @@ public MLTask(StreamInput input) throws IOException {
this.user = null;
}
this.async = input.readBoolean();
if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) {
if (input.readBoolean()) {
String mapStr = input.readString();
this.transformJob = gson.fromJson(mapStr, Map.class);
}
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeOptionalString(taskId);
out.writeOptionalString(modelId);
out.writeEnum(taskType);
Expand All @@ -149,6 +171,21 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeBoolean(async);
if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) {
if (transformJob != null) {
out.writeBoolean(true);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
out.writeString(gson.toJson(transformJob));
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
} else {
out.writeBoolean(false);
}
}
}

@Override
Expand Down Expand Up @@ -194,6 +231,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
builder.field(USER, user);
}
builder.field(IS_ASYNC_TASK_FIELD, async);
if (transformJob != null) {
builder.field(TRANSFORM_JOB_FIELD, transformJob);
}
return builder.endObject();
}

Expand All @@ -217,6 +257,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
String error = null;
User user = null;
boolean async = false;
Map<String, Object> transformJob = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -274,6 +315,9 @@ public static MLTask parse(XContentParser parser) throws IOException {
case IS_ASYNC_TASK_FIELD:
async = parser.booleanValue();
break;
case TRANSFORM_JOB_FIELD:
transformJob = parser.map();
break;
default:
parser.skipChildren();
break;
Expand All @@ -294,6 +338,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
.error(error)
.user(user)
.async(async)
.transformJob(transformJob)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
public enum MLTaskType {
TRAINING,
PREDICTION,
BATCH_PREDICTION,
TRAINING_AND_PREDICTION,
EXECUTION,
@Deprecated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
public enum ActionType {
PREDICT,
EXECUTE,
BATCH
BATCH,
CANCEL_BATCH,
BATCH_STATUS
}
}
Loading

0 comments on commit 02f1107

Please sign in to comment.