diff --git a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java index 276ce1774e..4e956e9c7e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java @@ -10,12 +10,16 @@ import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; import static org.opensearch.ml.utils.MLExceptionUtils.logException; import static org.opensearch.ml.utils.MLExceptionUtils.toJsonString; +import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; import java.time.Instant; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; @@ -131,10 +135,29 @@ protected void doExecute(Task task, ActionRequest request, ActionListener workNodesRemovedFromCluster = new HashSet<>(); + + if (workNodes != null && !workNodes.isEmpty()) { + Set allNodesInCluster = new HashSet<>(List.of(getAllNodes(clusterService))); + + workNodesRemovedFromCluster = workNodes + .stream() + .filter(node -> !allNodesInCluster.contains(node)) + .collect(Collectors.toSet()); + + if (!workNodesRemovedFromCluster.isEmpty()) { + workNodes.removeAll(workNodesRemovedFromCluster); + } + } + + if (workNodes == null || workNodes.isEmpty()) { + if (!workNodesRemovedFromCluster.isEmpty()) { + mlTaskCache.updateWorkerNode(workNodesRemovedFromCluster); + mlModelManager.removeModelWorkerNode(modelId, false, workNodesRemovedFromCluster.toArray(new String[0])); + } int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize(); MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED; - if (mlTaskCache.allNodeFailed()) { + if (mlTaskCache.allNodeFailed() || mlTaskCache.getWorkerNodeSize() == 0) { taskState = MLTaskState.FAILED; currentWorkerNodeCount = 0; } else { @@ -150,11 +173,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener updateFields = new HashMap<>(); updateFields.put(MLModel.MODEL_STATE_FIELD, modelState); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java index a6b519a2aa..0b3d7d116f 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java @@ -62,4 +62,9 @@ public int errorNodesCount() { public boolean allNodeFailed() { return workerNodeSize != null && errors.size() == workerNodeSize; } + + public void updateWorkerNode(Set nodesRemovedFromCluster) { + this.workerNodes.removeAll(nodesRemovedFromCluster); + this.workerNodeSize = this.workerNodeSize - nodesRemovedFromCluster.size(); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java index b5c1a4e1f7..62a4fa779f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java @@ -29,6 +29,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import static org.opensearch.ml.utils.TestHelper.clusterSetting; +import static org.opensearch.ml.utils.TestHelper.setupTestClusterState; import java.util.Arrays; import java.util.HashSet; @@ -43,6 +44,7 @@ import org.opensearch.Version; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -94,6 +96,8 @@ public class TransportForwardActionTests extends OpenSearchTestCase { private TransportForwardAction forwardAction; + private ClusterState testState; + Settings settings = Settings .builder() .put(ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE.getKey(), true) @@ -137,6 +141,9 @@ public void setup() { ) ); + testState = setupTestClusterState("test_node_id2"); + when(clusterService.state()).thenReturn(testState); + node1 = new DiscoveryNode(nodeId1, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); node2 = new DiscoveryNode(nodeId2, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index c40b11e1ce..523b8fee36 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -118,7 +118,7 @@ public void setup() throws IOException { encryptor = spy(new EncryptorImpl(null)); syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor); - testState = setupTestClusterState(); + testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java index 446dc74213..4c644300f2 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java @@ -151,7 +151,7 @@ public void setup() throws IOException { .build(); clusterName = new ClusterName("test cluster"); - testState = setupTestClusterState(); + testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java index cdd9136255..a034af502a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java @@ -147,7 +147,7 @@ public void setup() throws IOException { roleSet, Version.CURRENT ); - testState = setupTestClusterState(); + testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); clusterName = new ClusterName(clusterNameStr); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java index cf8e93765c..e98d1a5f31 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java @@ -67,7 +67,7 @@ public class RestMLUndeployModelActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - testState = setupTestClusterState(); + testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 513b2497eb..2632acca0b 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -461,11 +461,11 @@ public static ClusterState state(int numDataNodes, String indexName, String mapp return state(new ClusterName("test"), indexName, mapping, clusterManagerNode, clusterManagerNode, allNodes); } - public static ClusterState setupTestClusterState() { + public static ClusterState setupTestClusterState(String nodeId) { Set roleSet = new HashSet<>(); roleSet.add(DiscoveryNodeRole.DATA_ROLE); DiscoveryNode node = new DiscoveryNode( - "node", + nodeId, new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), new HashMap<>(), roleSet, diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java index ff71cadba2..7b3cf07db8 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java @@ -43,4 +43,8 @@ public class GenerativeQAProcessorConstants { .boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final String FEATURE_NOT_ENABLED_ERROR_MSG = RAG_PIPELINE_FEATURE_ENABLED.getKey() + " is not enabled."; + + public static final String RAG_NULL_GEN_QA_PARAMS_ERROR_MSG = "generative_qa_parameters not found." + + " Please provide ext.generative_qa_parameters to proceed." + + " For more info, refer: https://opensearch.org/docs/latest/search-plugins/conversational-search/#step-6-use-the-pipeline-for-rag"; } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index 7b1814c2a5..5ac106fb51 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -18,6 +18,7 @@ package org.opensearch.searchpipelines.questionanswering.generative; import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException; +import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG; import java.time.Duration; import java.time.Instant; @@ -126,6 +127,9 @@ public void processResponseAsync( } GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request); + if (params == null) { + throw new IllegalArgumentException(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG); + } Integer t = params.getTimeout(); if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) { diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java index a89b5c1731..4295ab450e 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG; import java.time.Instant; import java.util.Collections; @@ -646,6 +647,77 @@ public void testProcessResponseNullValueInteractions() throws Exception { })); } + public void testProcessResponseIllegalArgumentForNullParams() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG); + + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( + client, + alwaysOn + ).create(null, "tag", "desc", true, config, null); + + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); + List chatHistory = List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) + ); + doAnswer(invocation -> { + ((ActionListener>) invocation.getArguments()[2]).onResponse(chatHistory); + return null; + }).when(memoryClient).getInteractions(any(), anyInt(), any()); + processor.setMemoryClient(memoryClient); + + SearchRequest request = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(null); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent + .contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + .field("text", "passage" + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + processor.setLlm(llm); + + processor + .processResponseAsync( + request, + response, + null, + ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> {}) + ); + } + public void testProcessResponseIllegalArgument() throws Exception { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("llm_model cannot be null.");