Skip to content

Commit

Permalink
fix failed unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Sep 26, 2023
1 parent d5d3aa0 commit d49acc0
Showing 1 changed file with 20 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,30 @@ public void setup() throws IOException {
public void testExecuteTask_OnLocalNode() {
setupMocks(true, false, false, false);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager).add(any(MLTask.class));
verify(client).get(any(), any());
verify(mlTaskManager).remove(anyString());
}

public void testExecuteTask_OnLocalNode_RemoteModel() {
setupMocks(true, false, false, false);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any());
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener).onFailure(argumentCaptor.capture());
assertTrue(argumentCaptor.getValue().getMessage().contains("Model not ready yet."));
verify(mlTaskManager, never()).add(any(MLTask.class));
verify(client, never()).get(any(), any());
}

public void testExecuteTask_OnLocalNode_QueryInput() {
setupMocks(true, false, false, false);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithQuery, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener);
verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager).add(any(MLTask.class));
Expand All @@ -234,7 +246,7 @@ public void testExecuteTask_OnLocalNode_QueryInput() {
public void testExecuteTask_OnLocalNode_QueryInput_Failure() {
setupMocks(true, true, false, false);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithQuery, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener);
verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager, never()).add(any(MLTask.class));
Expand All @@ -245,7 +257,7 @@ public void testExecuteTask_NoPermission() {
setupMocks(true, true, false, false);
threadContext.stashContext();
threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "test_user|test_role|test_tenant");
taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(mlTaskManager).add(any(MLTask.class));
verify(mlTaskManager).remove(anyString());
verify(client).get(any(), any());
Expand All @@ -256,14 +268,14 @@ public void testExecuteTask_NoPermission() {

public void testExecuteTask_OnRemoteNode() {
setupMocks(false, false, false, false);
taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(transportService).sendRequest(eq(remoteNode), eq(MLPredictionTaskAction.NAME), eq(requestWithDataFrame), any());
}

public void testExecuteTask_OnLocalNode_GetModelFail() {
setupMocks(true, false, true, false);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager).add(any(MLTask.class));
Expand All @@ -277,7 +289,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() {
setupMocks(true, false, false, false);
requestWithDataFrame = MLPredictionTaskRequest.builder().mlInput(mlInputWithDataFrame).build();

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager).add(any(MLTask.class));
Expand All @@ -291,7 +303,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() {
public void testExecuteTask_OnLocalNode_NullGetResponse() {
setupMocks(true, false, false, true);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager).add(any(MLTask.class));
Expand Down

0 comments on commit d49acc0

Please sign in to comment.