Skip to content

Commit

Permalink
add beckrock url in the allowed list and more UTs
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Sep 12, 2024
1 parent 57051bd commit e25134c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,35 @@ public void testFilterFieldMapping_MatchingPrefix() {
assertEquals(Arrays.asList("$.custom_id"), result.get("_id"));
}

@Test
public void testFilterFieldMappingSoleSource_MatchingPrefix() {
// Arrange
Map<String, Object> fieldMap = new HashMap<>();
fieldMap.put("question", "source[0].$.body.input[0]");
fieldMap.put("question_embedding", "source[0].$.response.body.data[0].embedding");
fieldMap.put("answer", "source[0].$.body.input[1]");
fieldMap.put("answer_embedding", "$.response.body.data[1].embedding");
fieldMap.put("_id", Arrays.asList("$.custom_id", "source[1].$.custom_id"));

MLBatchIngestionInput mlBatchIngestionInput = new MLBatchIngestionInput(
"indexName",
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
);

// Act
Map<String, Object> result = s3DataIngestion.filterFieldMappingSoleSource(mlBatchIngestionInput);

// Assert
assertEquals(6, result.size());

assertEquals("$.body.input[0]", result.get("question"));
assertEquals("$.response.body.data[0].embedding", result.get("question_embedding"));
assertEquals(Arrays.asList("$.custom_id"), result.get("_id"));
}

@Test
public void testProcessFieldMapping_FromSM() {
String jsonStr =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ private MLCommonsSettings() {}
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$",
"^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$"
"^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://bedrock\\..*[a-z0-9-]\\.amazonaws\\.com/.*$"
),
Function.identity(),
Setting.Property.NodeScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,29 @@
package org.opensearch.ml.action.batch;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.MLTask.ERROR_FIELD;
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.engine.ingest.S3DataIngestion.SOURCE;
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand All @@ -45,6 +52,8 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.jayway.jsonpath.PathNotFoundException;

public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
@Mock
private Client client;
Expand All @@ -62,6 +71,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
ActionListener<MLBatchIngestionResponse> actionListener;
@Mock
ThreadPool threadPool;
@Mock
ExecutorService executorService;

private TransportBatchIngestionAction batchAction;
private MLBatchIngestionInput batchInput;
Expand Down Expand Up @@ -105,9 +116,42 @@ public void test_doExecute_success() {
listener.onResponse(indexResponse);
return null;
}).when(mlTaskManager).createMLTask(isA(MLTask.class), isA(ActionListener.class));
doReturn(executorService).when(threadPool).executor(INGEST_THREAD_POOL);
doAnswer(invocation -> {
Runnable runnable = invocation.getArgument(0);
runnable.run();
return null;
}).when(executorService).execute(any(Runnable.class));

batchAction.doExecute(task, mlBatchIngestionRequest, actionListener);

verify(actionListener).onResponse(any(MLBatchIngestionResponse.class));
verify(threadPool).executor(INGEST_THREAD_POOL);
}

public void test_doExecute_ExecuteWithNoErrorHandling() {
batchAction.executeWithErrorHandling(() -> {}, "taskId");

verify(mlTaskManager, never()).updateMLTask(anyString(), isA(Map.class), anyLong(), anyBoolean());
}

public void test_doExecute_ExecuteWithPathNotFoundException() {
batchAction.executeWithErrorHandling(() -> { throw new PathNotFoundException("jsonPath not found!"); }, "taskId");

verify(mlTaskManager)
.updateMLTask("taskId", Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "jsonPath not found!"), TASK_SEMAPHORE_TIMEOUT, true);
}

public void test_doExecute_RuntimeException() {
batchAction.executeWithErrorHandling(() -> { throw new RuntimeException("runtime exception in the ingestion!"); }, "taskId");

verify(mlTaskManager)
.updateMLTask(
"taskId",
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "runtime exception in the ingestion!"),
TASK_SEMAPHORE_TIMEOUT,
true
);
}

public void test_doExecute_handleSuccessRate100() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,9 @@ public void testValidateBatchPredictionSuccess() throws IOException {
"output",
"{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"array\"}}}"
);
ModelTensorOutput modelTensorOutput = ModelTensorOutput
.builder()
.mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build()))
.build();
ModelTensors modelTensors = ModelTensors.builder().statusCode(200).mlModelTensors(List.of(modelTensor)).statusCode(200).build();
modelTensors.setStatusCode(200);
ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(modelTensors)).build();
doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build());
Expand Down

0 comments on commit e25134c

Please sign in to comment.