From e25134c384204609ed0c8e228baeee861b7e53da Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Thu, 12 Sep 2024 13:05:51 -0700 Subject: [PATCH] add bedrock url in the allowed list and more UTs Signed-off-by: Xun Zhang --- .../engine/ingest/AbstractIngestionTests.java | 29 ++++++++++++ .../ml/settings/MLCommonsSettings.java | 3 +- .../TransportBatchIngestionActionTests.java | 44 +++++++++++++++++++ .../ml/task/MLPredictTaskRunnerTests.java | 7 ++- 4 files changed, 78 insertions(+), 5 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java index a4c155ba77..1f1653b31c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java @@ -204,6 +204,35 @@ public void testFilterFieldMapping_MatchingPrefix() { assertEquals(Arrays.asList("$.custom_id"), result.get("_id")); } + @Test + public void testFilterFieldMappingSoleSource_MatchingPrefix() { + // Arrange + Map fieldMap = new HashMap<>(); + fieldMap.put("question", "source[0].$.body.input[0]"); + fieldMap.put("question_embedding", "source[0].$.response.body.data[0].embedding"); + fieldMap.put("answer", "source[0].$.body.input[1]"); + fieldMap.put("answer_embedding", "$.response.body.data[1].embedding"); + fieldMap.put("_id", Arrays.asList("$.custom_id", "source[1].$.custom_id")); + + MLBatchIngestionInput mlBatchIngestionInput = new MLBatchIngestionInput( + "indexName", + fieldMap, + ingestFields, + new HashMap<>(), + new HashMap<>() + ); + + // Act + Map result = s3DataIngestion.filterFieldMappingSoleSource(mlBatchIngestionInput); + + // Assert + assertEquals(6, result.size()); + + assertEquals("$.body.input[0]", result.get("question")); + assertEquals("$.response.body.data[0].embedding", result.get("question_embedding")); + 
assertEquals(Arrays.asList("$.custom_id"), result.get("_id")); + } + @Test public void testProcessFieldMapping_FromSM() { String jsonStr = diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 6daffd30fd..339116226d 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -146,7 +146,8 @@ private MLCommonsSettings() {} "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$", "^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", - "^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$" + "^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://bedrock\\..*[a-z0-9-]\\.amazonaws\\.com/.*$" ), Function.identity(), Setting.Property.NodeScope, diff --git a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java index 2916359110..092edfe951 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java @@ -6,9 +6,14 @@ package org.opensearch.ml.action.batch; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.MLTask.ERROR_FIELD; @@ -16,12 
+21,14 @@ import static org.opensearch.ml.common.MLTaskState.COMPLETED; import static org.opensearch.ml.common.MLTaskState.FAILED; import static org.opensearch.ml.engine.ingest.S3DataIngestion.SOURCE; +import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ExecutorService; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -45,6 +52,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import com.jayway.jsonpath.PathNotFoundException; + public class TransportBatchIngestionActionTests extends OpenSearchTestCase { @Mock private Client client; @@ -62,6 +71,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase { ActionListener actionListener; @Mock ThreadPool threadPool; + @Mock + ExecutorService executorService; private TransportBatchIngestionAction batchAction; private MLBatchIngestionInput batchInput; @@ -105,9 +116,42 @@ public void test_doExecute_success() { listener.onResponse(indexResponse); return null; }).when(mlTaskManager).createMLTask(isA(MLTask.class), isA(ActionListener.class)); + doReturn(executorService).when(threadPool).executor(INGEST_THREAD_POOL); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); verify(actionListener).onResponse(any(MLBatchIngestionResponse.class)); + verify(threadPool).executor(INGEST_THREAD_POOL); + } + + public void test_doExecute_ExecuteWithNoErrorHandling() { + batchAction.executeWithErrorHandling(() -> {}, "taskId"); + + verify(mlTaskManager, never()).updateMLTask(anyString(), isA(Map.class), anyLong(), anyBoolean()); + } + + public 
void test_doExecute_ExecuteWithPathNotFoundException() { + batchAction.executeWithErrorHandling(() -> { throw new PathNotFoundException("jsonPath not found!"); }, "taskId"); + + verify(mlTaskManager) + .updateMLTask("taskId", Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "jsonPath not found!"), TASK_SEMAPHORE_TIMEOUT, true); + } + + public void test_doExecute_RuntimeException() { + batchAction.executeWithErrorHandling(() -> { throw new RuntimeException("runtime exception in the ingestion!"); }, "taskId"); + + verify(mlTaskManager) + .updateMLTask( + "taskId", + Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "runtime exception in the ingestion!"), + TASK_SEMAPHORE_TIMEOUT, + true + ); } public void test_doExecute_handleSuccessRate100() { diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 064008a9c4..223f2ce5a5 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -447,10 +447,9 @@ public void testValidateBatchPredictionSuccess() throws IOException { "output", "{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"array\"}}}" ); - ModelTensorOutput modelTensorOutput = ModelTensorOutput - .builder() - .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build())) - .build(); + ModelTensors modelTensors = ModelTensors.builder().statusCode(200).mlModelTensors(List.of(modelTensor)).statusCode(200).build(); + modelTensors.setStatusCode(200); + ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(modelTensors)).build(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build());