Skip to content

Commit

Permalink
fix field mapping, add more error handling and remove checking jobId … (
Browse files Browse the repository at this point in the history
#2933)

* fix field mapping, add more error handling and remove checking jobId filed in batch job response

Signed-off-by: Xun Zhang <[email protected]>

* add beckrock url in the allowed list and more UTs

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
(cherry picked from commit 30228ea)
  • Loading branch information
Zhangxunmt authored and github-actions[bot] committed Sep 12, 2024
1 parent 0135cb9 commit 0b6252d
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,55 @@ protected double calculateSuccessRate(List<Double> successRates) {
);
}

/**
* Filters fields in the map where the value contains the specified source index as a prefix.
* When there is only one source file, users can skip the source[] prefix
*
* @param mlBatchIngestionInput The MLBatchIngestionInput.
* @return A new map of <fieldName: JsonPath> for all fields to be ingested.
*/
protected Map<String, Object> filterFieldMappingSoleSource(MLBatchIngestionInput mlBatchIngestionInput) {
Map<String, Object> fieldMap = mlBatchIngestionInput.getFieldMapping();
String prefix = "source[0]";

Map<String, Object> filteredFieldMap = fieldMap.entrySet().stream().filter(entry -> {
Object value = entry.getValue();
if (value instanceof String) {
String jsonPath = ((String) value);
return jsonPath.contains(prefix) || !jsonPath.startsWith("source");
} else if (value instanceof List) {
return ((List<String>) value).stream().anyMatch(val -> (val.contains(prefix) || !val.startsWith("source")));
}
return false;
}).collect(Collectors.toMap(Map.Entry::getKey, entry -> {
Object value = entry.getValue();
if (value instanceof String) {
return getJsonPath((String) value);
} else if (value instanceof List) {
return ((List<String>) value)
.stream()
.filter(val -> (val.contains(prefix) || !val.startsWith("source")))
.map(StringUtils::getJsonPath)
.collect(Collectors.toList());
}
return null;
}));

String[] ingestFields = mlBatchIngestionInput.getIngestFields();
if (ingestFields != null) {
Arrays
.stream(ingestFields)
.filter(Objects::nonNull)
.filter(val -> (val.contains(prefix) || !val.startsWith("source")))
.map(StringUtils::getJsonPath)
.forEach(jsonPath -> {
filteredFieldMap.put(obtainFieldNameFromJsonPath(jsonPath), jsonPath);
});
}

return filteredFieldMap;
}

/**
* Filters fields in the map where the value contains the specified source index as a prefix.
*
Expand Down Expand Up @@ -159,7 +208,7 @@ protected void batchIngest(
BulkRequest bulkRequest = new BulkRequest();
sourceLines.stream().forEach(jsonStr -> {
Map<String, Object> filteredMapping = isSoleSource
? mlBatchIngestionInput.getFieldMapping()
? filterFieldMappingSoleSource(mlBatchIngestionInput)
: filterFieldMapping(mlBatchIngestionInput, sourceIndex);
Map<String, Object> jsonMap = processFieldMapping(jsonStr, filteredMapping);
if (jsonMap.isEmpty()) {
Expand All @@ -174,7 +223,7 @@ protected void batchIngest(
if (!jsonMap.containsKey("_id")) {
throw new IllegalArgumentException("The id filed must be provided to match documents for multiple sources");
}
String id = (String) jsonMap.remove("_id");
String id = String.valueOf(jsonMap.remove("_id"));
UpdateRequest updateRequest = new UpdateRequest(mlBatchIngestionInput.getIndexName(), id).doc(jsonMap).upsert(jsonMap);
bulkRequest.add(updateRequest);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,35 @@ public void testFilterFieldMapping_MatchingPrefix() {
assertEquals(Arrays.asList("$.custom_id"), result.get("_id"));
}

@Test
public void testFilterFieldMappingSoleSource_MatchingPrefix() {
// Arrange
Map<String, Object> fieldMap = new HashMap<>();
fieldMap.put("question", "source[0].$.body.input[0]");
fieldMap.put("question_embedding", "source[0].$.response.body.data[0].embedding");
fieldMap.put("answer", "source[0].$.body.input[1]");
fieldMap.put("answer_embedding", "$.response.body.data[1].embedding");
fieldMap.put("_id", Arrays.asList("$.custom_id", "source[1].$.custom_id"));

MLBatchIngestionInput mlBatchIngestionInput = new MLBatchIngestionInput(
"indexName",
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
);

// Act
Map<String, Object> result = s3DataIngestion.filterFieldMappingSoleSource(mlBatchIngestionInput);

// Assert
assertEquals(6, result.size());

assertEquals("$.body.input[0]", result.get("question"));
assertEquals("$.response.body.data[0].embedding", result.get("question_embedding"));
assertEquals(Arrays.asList("$.custom_id"), result.get("_id"));
}

@Test
public void testProcessFieldMapping_FromSM() {
String jsonStr =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;

import java.time.Instant;
Expand Down Expand Up @@ -41,6 +41,8 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.jayway.jsonpath.PathNotFoundException;

import lombok.extern.log4j.Log4j2;

@Log4j2
Expand Down Expand Up @@ -92,9 +94,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
threadPool.executor(TRAIN_THREAD_POOL).execute(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
executeWithErrorHandling(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
}, taskId);
});
} catch (Exception ex) {
log.error("Failed in batch ingestion", ex);
Expand Down Expand Up @@ -125,6 +129,30 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
}
}

protected void executeWithErrorHandling(Runnable task, String taskId) {
try {
task.run();
} catch (PathNotFoundException jsonPathNotFoundException) {
log.error("Error in jsonParse fields", jsonPathNotFoundException);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, jsonPathNotFoundException.getMessage()),
TASK_SEMAPHORE_TIMEOUT,
true
);
} catch (Exception e) {
log.error("Error in ingest, failed to produce a successRate", e);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e)),
TASK_SEMAPHORE_TIMEOUT,
true
);
}
}

protected void handleSuccessRate(double successRate, String taskId) {
if (successRate == 100) {
mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, COMPLETED), 5000, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ public class MachineLearningPlugin extends Plugin
public static final String TRAIN_THREAD_POOL = "opensearch_ml_train";
public static final String PREDICT_THREAD_POOL = "opensearch_ml_predict";
public static final String REMOTE_PREDICT_THREAD_POOL = "opensearch_ml_predict_remote";
public static final String INGEST_THREAD_POOL = "opensearch_ml_ingest";
public static final String REGISTER_THREAD_POOL = "opensearch_ml_register";
public static final String DEPLOY_THREAD_POOL = "opensearch_ml_deploy";
public static final String ML_BASE_URI = "/_plugins/_ml";
Expand Down Expand Up @@ -886,6 +887,14 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
ML_THREAD_POOL_PREFIX + REMOTE_PREDICT_THREAD_POOL,
false
);
FixedExecutorBuilder batchIngestThreadPool = new FixedExecutorBuilder(
settings,
INGEST_THREAD_POOL,
OpenSearchExecutors.allocatedProcessors(settings) * 4,
30,
ML_THREAD_POOL_PREFIX + INGEST_THREAD_POOL,
false
);

return ImmutableList
.of(
Expand All @@ -895,7 +904,8 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
executeThreadPool,
trainThreadPool,
predictThreadPool,
remotePredictThreadPool
remotePredictThreadPool,
batchIngestThreadPool
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ private MLCommonsSettings() {}
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$",
"^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$"
"^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://bedrock\\..*[a-z0-9-]\\.amazonaws\\.com/.*$"
),
Function.identity(),
Setting.Property.NodeScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,13 +358,13 @@ private void runPredict(
&& tensorOutput.getMlModelOutputs() != null
&& !tensorOutput.getMlModelOutputs().isEmpty()) {
ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0);
Integer statusCode = modelOutput.getStatusCode();
if (modelOutput.getMlModelTensors() != null && !modelOutput.getMlModelTensors().isEmpty()) {
Map<String, Object> dataAsMap = (Map<String, Object>) modelOutput
.getMlModelTensors()
.get(0)
.getDataAsMap();
if (dataAsMap != null
&& (dataAsMap.containsKey("TransformJobArn") || dataAsMap.containsKey("id"))) {
if (dataAsMap != null && statusCode != null && statusCode >= 200 && statusCode < 300) {
remoteJob.putAll(dataAsMap);
mlTask.setRemoteJob(remoteJob);
mlTask.setTaskId(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,29 @@
package org.opensearch.ml.action.batch;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.MLTask.ERROR_FIELD;
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.engine.ingest.S3DataIngestion.SOURCE;
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand All @@ -45,6 +52,8 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.jayway.jsonpath.PathNotFoundException;

public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
@Mock
private Client client;
Expand All @@ -62,6 +71,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
ActionListener<MLBatchIngestionResponse> actionListener;
@Mock
ThreadPool threadPool;
@Mock
ExecutorService executorService;

private TransportBatchIngestionAction batchAction;
private MLBatchIngestionInput batchInput;
Expand Down Expand Up @@ -105,9 +116,42 @@ public void test_doExecute_success() {
listener.onResponse(indexResponse);
return null;
}).when(mlTaskManager).createMLTask(isA(MLTask.class), isA(ActionListener.class));
doReturn(executorService).when(threadPool).executor(INGEST_THREAD_POOL);
doAnswer(invocation -> {
Runnable runnable = invocation.getArgument(0);
runnable.run();
return null;
}).when(executorService).execute(any(Runnable.class));

batchAction.doExecute(task, mlBatchIngestionRequest, actionListener);

verify(actionListener).onResponse(any(MLBatchIngestionResponse.class));
verify(threadPool).executor(INGEST_THREAD_POOL);
}

public void test_doExecute_ExecuteWithNoErrorHandling() {
batchAction.executeWithErrorHandling(() -> {}, "taskId");

verify(mlTaskManager, never()).updateMLTask(anyString(), isA(Map.class), anyLong(), anyBoolean());
}

public void test_doExecute_ExecuteWithPathNotFoundException() {
batchAction.executeWithErrorHandling(() -> { throw new PathNotFoundException("jsonPath not found!"); }, "taskId");

verify(mlTaskManager)
.updateMLTask("taskId", Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "jsonPath not found!"), TASK_SEMAPHORE_TIMEOUT, true);
}

public void test_doExecute_RuntimeException() {
batchAction.executeWithErrorHandling(() -> { throw new RuntimeException("runtime exception in the ingestion!"); }, "taskId");

verify(mlTaskManager)
.updateMLTask(
"taskId",
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "runtime exception in the ingestion!"),
TASK_SEMAPHORE_TIMEOUT,
true
);
}

public void test_doExecute_handleSuccessRate100() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,9 @@ public void testValidateBatchPredictionSuccess() throws IOException {
"output",
"{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"array\"}}}"
);
ModelTensorOutput modelTensorOutput = ModelTensorOutput
.builder()
.mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build()))
.build();
ModelTensors modelTensors = ModelTensors.builder().statusCode(200).mlModelTensors(List.of(modelTensor)).statusCode(200).build();
modelTensors.setStatusCode(200);
ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(modelTensors)).build();
doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build());
Expand Down

0 comments on commit 0b6252d

Please sign in to comment.