diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index fdbfb52d0f..b49e075aea 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -28,7 +28,7 @@ public class MLPreProcessFunction { " }\n" + " }\n" + " builder.append(\"]\");\n" + - " def parameters = \"{\" +\"\\\"prompt\\\":\" + builder + \"}\";\n" + + " def parameters = \"{\" +\"\\\"texts\\\":\" + builder + \"}\";\n" + " return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";"); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" + diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 0ea0ca0724..3a623a1b65 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -119,8 +119,6 @@ public void dispatchTask( ActionListener listener ) { String modelId = request.getModelId(); - MLInput input = request.getMlInput(); - FunctionName algorithm = input.getAlgorithm(); try { ActionListener actionListener = ActionListener.wrap(node -> { if (clusterService.localNode().getId().equals(node.getId())) { @@ -133,9 +131,9 @@ public void dispatchTask( transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener)); } }, e -> { listener.onFailure(e); }); - String[] workerNodes = mlModelManager.getWorkerNodes(modelId, algorithm, true); + String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true); if (workerNodes == null || workerNodes.length == 0) { - if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { + if (functionName == FunctionName.TEXT_EMBEDDING || functionName == FunctionName.REMOTE) { listener .onFailure( new IllegalArgumentException( @@ -144,7 +142,7 @@ public void dispatchTask( ); return; } else { - workerNodes = nodeHelper.getEligibleNodeIds(algorithm); + workerNodes = nodeHelper.getEligibleNodeIds(functionName); } } mlTaskDispatcher.dispatchPredictTask(workerNodes, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 26c19394d2..6e87cdae0a 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -212,7 +212,7 @@ public void setup() throws IOException { public void testExecuteTask_OnLocalNode() { setupMocks(true, false, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -220,10 +220,22 @@ public void testExecuteTask_OnLocalNode() { verify(mlTaskManager).remove(anyString()); } + public void testExecuteTask_OnLocalNode_RemoteModel() { + setupMocks(true, false, false, false); + + taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertTrue(argumentCaptor.getValue().getMessage().contains("Model not ready yet.")); + verify(mlTaskManager, never()).add(any(MLTask.class)); + verify(client, never()).get(any(), any()); + } + public void testExecuteTask_OnLocalNode_QueryInput() { setupMocks(true, false, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithQuery, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -234,7 +246,7 @@ public void testExecuteTask_OnLocalNode_QueryInput() { public void testExecuteTask_OnLocalNode_QueryInput_Failure() { setupMocks(true, true, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithQuery, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager, never()).add(any(MLTask.class)); @@ -245,7 +257,7 @@ public void testExecuteTask_NoPermission() { setupMocks(true, true, false, false); threadContext.stashContext(); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "test_user|test_role|test_tenant"); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlTaskManager).add(any(MLTask.class)); verify(mlTaskManager).remove(anyString()); verify(client).get(any(), any()); @@ -256,14 +268,14 @@ public void testExecuteTask_NoPermission() { public void testExecuteTask_OnRemoteNode() { setupMocks(false, false, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(transportService).sendRequest(eq(remoteNode), eq(MLPredictionTaskAction.NAME), eq(requestWithDataFrame), any()); } public void testExecuteTask_OnLocalNode_GetModelFail() { setupMocks(true, false, true, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -277,7 +289,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { setupMocks(true, false, false, false); requestWithDataFrame = MLPredictionTaskRequest.builder().mlInput(mlInputWithDataFrame).build(); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -291,7 +303,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { public void testExecuteTask_OnLocalNode_NullGetResponse() { setupMocks(true, false, false, true); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class));