Skip to content

Commit

Permalink
add bedrock batch job post process function; enhance remote job statu…
Browse files Browse the repository at this point in the history
…s parsing (#2955)

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Sep 16, 2024
1 parent 091f5df commit 0d26931
Show file tree
Hide file tree
Showing 12 changed files with 344 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ public enum MLTaskState {
COMPLETED,
FAILED,
CANCELLED,
COMPLETED_WITH_ERROR
COMPLETED_WITH_ERROR,
CANCELLING,
EXPIRED
}
4 changes: 2 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/MLTaskType.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
public enum MLTaskType {
TRAINING,
PREDICTION,
BATCH_PREDICTION,
TRAINING_AND_PREDICTION,
EXECUTION,
@Deprecated
Expand All @@ -17,5 +16,6 @@ public enum MLTaskType {
LOAD_MODEL,
REGISTER_MODEL,
DEPLOY_MODEL,
BATCH_INGEST
BATCH_INGEST,
BATCH_PREDICTION
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Map;
import java.util.function.Function;

import org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
Expand All @@ -20,6 +21,7 @@ public class MLPostProcessFunction {
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
Expand All @@ -31,17 +33,20 @@ public class MLPostProcessFunction {
static {
EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction();
BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction();
BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction();
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$");
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.output.model.ModelTensor;

public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction<Map<String, String>> {
public static final String JOB_ARN = "jobArn";
public static final String PROCESSED_JOB_ARN = "processedJobArn";

@Override
public void validate(Object input) {
if (!(input instanceof Map)) {
throw new IllegalArgumentException("Post process function input is not a Map.");
}
Map<String, String> jobInfo = (Map<String, String>) input;
if (!(jobInfo.containsKey(JOB_ARN))) {
throw new IllegalArgumentException("job arn is missing.");
}
}

@Override
public List<ModelTensor> process(Map<String, String> jobInfo) {
List<ModelTensor> modelTensors = new ArrayList<>();
Map<String, String> processedResult = new HashMap<>();
processedResult.putAll(jobInfo);
String jobArn = jobInfo.get(JOB_ARN);
processedResult.put(PROCESSED_JOB_ARN, jobArn.replace("/", "%2F"));
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(processedResult).build());
return modelTensors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

public class MLCancelBatchJobAction extends ActionType<MLCancelBatchJobResponse> {
public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction();
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job";
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel";

private MLCancelBatchJobAction() {
super(NAME, MLCancelBatchJobResponse::new);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction.JOB_ARN;
import static org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction.PROCESSED_JOB_ARN;

import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.output.model.ModelTensor;

public class BedrockBatchJobArnPostProcessFunctionTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

BedrockBatchJobArnPostProcessFunction function;

@Before
public void setUp() {
function = new BedrockBatchJobArnPostProcessFunction();
}

@Test
public void process_WrongInput_NotMap() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Post process function input is not a Map.");
function.apply("abc");
}

@Test
public void process_WrongInput_NotContainJobArn() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("job arn is missing.");
function.apply(Map.of("test", "value"));
}

@Test
public void process_CorrectInput() {
String jobArn = "arn:aws:bedrock:us-east-1:12345678912:model-invocation-job/w1xtlm0ik3e1";
List<ModelTensor> result = function.apply(Map.of(JOB_ARN, jobArn));
assertEquals(1, result.size());
assertEquals(jobArn, result.get(0).getDataAsMap().get(JOB_ARN));
assertEquals(
"arn:aws:bedrock:us-east-1:12345678912:model-invocation-job%2Fw1xtlm0ik3e1",
result.get(0).getDataAsMap().get(PROCESSED_JOB_ARN)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,25 @@
import static org.opensearch.ml.common.MLTask.REMOTE_JOB_FIELD;
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.CANCELLED;
import static org.opensearch.ml.common.MLTaskState.CANCELLING;
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.EXPIRED;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
Expand All @@ -30,6 +41,8 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -80,6 +93,12 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
MLTaskManager mlTaskManager;
MLModelCacheHelper modelCacheHelper;

volatile List<String> remoteJobStatusFields;
volatile Pattern remoteJobCompletedStatusRegexPattern;
volatile Pattern remoteJobCancelledStatusRegexPattern;
volatile Pattern remoteJobCancellingStatusRegexPattern;
volatile Pattern remoteJobExpiredStatusRegexPattern;

@Inject
public GetTaskTransportAction(
TransportService transportService,
Expand All @@ -91,7 +110,8 @@ public GetTaskTransportAction(
ConnectorAccessControlHelper connectorAccessControlHelper,
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager
MLModelManager mlModelManager,
Settings settings
) {
super(MLTaskGetAction.NAME, transportService, actionFilters, MLTaskGetRequest::new);
this.client = client;
Expand All @@ -102,6 +122,44 @@ public GetTaskTransportAction(
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;

remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_JOB_STATUS_FIELD, it -> remoteJobStatusFields = it);
initializeRegexPattern(
ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX,
settings,
clusterService,
(regex) -> remoteJobCompletedStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)
);
initializeRegexPattern(
ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX,
settings,
clusterService,
(regex) -> remoteJobCancelledStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)
);
initializeRegexPattern(
ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX,
settings,
clusterService,
(regex) -> remoteJobCancellingStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)
);
initializeRegexPattern(
ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
settings,
clusterService,
(regex) -> remoteJobExpiredStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)
);
}

private void initializeRegexPattern(
Setting<String> setting,
Settings settings,
ClusterService clusterService,
Consumer<String> patternInitializer
) {
String regex = setting.get(settings);
patternInitializer.accept(regex);
clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> patternInitializer.accept(it));
}

@Override
Expand Down Expand Up @@ -210,7 +268,7 @@ private void executeConnector(
MLInput mlInput,
String taskId,
MLTask mlTask,
Map<String, Object> transformJob,
Map<String, Object> remoteJob,
ActionListener<MLTaskGetResponse> actionListener
) {
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
Expand All @@ -222,15 +280,15 @@ private void executeConnector(
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(mlTask, taskId, taskResponse, transformJob, actionListener);
processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener);
}, e -> { actionListener.onFailure(e); }));
} else {
actionListener
.onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
}
}

private void processTaskResponse(
protected void processTaskResponse(
MLTask mlTask,
String taskId,
MLTaskResponse taskResponse,
Expand All @@ -248,15 +306,11 @@ private void processTaskResponse(
Map<String, Object> updatedTask = new HashMap<>();
updatedTask.put(REMOTE_JOB_FIELD, remoteJob);

if ((remoteJob.containsKey("status") && remoteJob.get("status").equals("completed"))
|| (remoteJob.containsKey("TransformJobStatus") && remoteJob.get("TransformJobStatus").equals("Completed"))) {
updatedTask.put(STATE_FIELD, COMPLETED);
mlTask.setState(COMPLETED);

} else if ((remoteJob.containsKey("status") && remoteJob.get("status").equals("cancelled"))
|| (remoteJob.containsKey("TransformJobStatus") && remoteJob.get("TransformJobStatus").equals("Stopped"))) {
updatedTask.put(STATE_FIELD, CANCELLED);
mlTask.setState(CANCELLED);
for (String statusField : remoteJobStatusFields) {
String statusValue = String.valueOf(remoteJob.get(statusField));
if (remoteJob.containsKey(statusField)) {
updateTaskState(updatedTask, mlTask, statusValue);
}
}
mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> {
actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build());
Expand All @@ -280,4 +334,25 @@ private void processTaskResponse(
log.error("Unable to fetch status for ml task ", e);
}
}

private void updateTaskState(Map<String, Object> updatedTask, MLTask mlTask, String statusValue) {
if (matchesPattern(remoteJobCancellingStatusRegexPattern, statusValue)) {
updatedTask.put(STATE_FIELD, CANCELLING);
mlTask.setState(CANCELLING);
} else if (matchesPattern(remoteJobCancelledStatusRegexPattern, statusValue)) {
updatedTask.put(STATE_FIELD, CANCELLED);
mlTask.setState(CANCELLED);
} else if (matchesPattern(remoteJobCompletedStatusRegexPattern, statusValue)) {
updatedTask.put(STATE_FIELD, COMPLETED);
mlTask.setState(COMPLETED);
} else if (matchesPattern(remoteJobExpiredStatusRegexPattern, statusValue)) {
updatedTask.put(STATE_FIELD, EXPIRED);
mlTask.setState(EXPIRED);
}
}

private boolean matchesPattern(Pattern pattern, String input) {
Matcher matcher = pattern.matcher(input);
return matcher.find();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,12 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE,
MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED
MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;

//TODO: Rename class and support cancelling more tasks. Now only support cancelling remote job
public class RestMLCancelBatchJobAction extends BaseRestHandler {
private static final String ML_CANCEL_BATCH_ACTION = "ml_cancel_batch_action";
private static final String ML_CANCEL_TASK_ACTION = "ml_cancel_task_action";

/**
* Constructor
Expand All @@ -33,18 +34,13 @@ public RestMLCancelBatchJobAction() {}

@Override
public String getName() {
return ML_CANCEL_BATCH_ACTION;
return ML_CANCEL_TASK_ACTION;
}

@Override
public List<Route> routes() {
return ImmutableList
.of(
new Route(
RestRequest.Method.POST,
String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel_batch", ML_BASE_URI, PARAMETER_TASK_ID)
)
);
.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel", ML_BASE_URI, PARAMETER_TASK_ID)));
}

@Override
Expand Down
Loading

0 comments on commit 0d26931

Please sign in to comment.