Skip to content

Commit

Permalink
Merge branch '2.16' into bwc_rag_request_params
Browse files Browse the repository at this point in the history
  • Loading branch information
pyek-bot committed Dec 30, 2024
2 parents b426423 + af0a137 commit 2988c87
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
import static org.opensearch.ml.utils.MLExceptionUtils.toJsonString;
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;

import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
Expand Down Expand Up @@ -131,10 +135,29 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
syncModelWorkerNodes(modelId, functionName);
}

if (workNodes == null || workNodes.size() == 0) {
Set<String> workNodesRemovedFromCluster = new HashSet<>();

if (workNodes != null && !workNodes.isEmpty()) {
Set<String> allNodesInCluster = new HashSet<>(List.of(getAllNodes(clusterService)));

workNodesRemovedFromCluster = workNodes
.stream()
.filter(node -> !allNodesInCluster.contains(node))
.collect(Collectors.toSet());

if (!workNodesRemovedFromCluster.isEmpty()) {
workNodes.removeAll(workNodesRemovedFromCluster);
}
}

if (workNodes == null || workNodes.isEmpty()) {
if (!workNodesRemovedFromCluster.isEmpty()) {
mlTaskCache.updateWorkerNode(workNodesRemovedFromCluster);
mlModelManager.removeModelWorkerNode(modelId, false, workNodesRemovedFromCluster.toArray(new String[0]));
}
int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
if (mlTaskCache.allNodeFailed()) {
if (mlTaskCache.allNodeFailed() || mlTaskCache.getWorkerNodeSize() == 0) {
taskState = MLTaskState.FAILED;
currentWorkerNodeCount = 0;
} else {
Expand All @@ -150,11 +173,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
mlTaskManager.updateMLTask(taskId, builder.build(), TASK_SEMAPHORE_TIMEOUT, true);

MLModelState modelState;
if (!mlTaskCache.allNodeFailed()) {
modelState = mlTaskCache.hasError() ? MLModelState.PARTIALLY_DEPLOYED : MLModelState.DEPLOYED;
} else {
if (mlTaskCache.allNodeFailed() || mlTaskCache.getWorkerNodeSize() == 0) {
modelState = MLModelState.DEPLOY_FAILED;
log.error("deploy model failed on all nodes, model id: {}", modelId);
} else {
modelState = mlTaskCache.hasError() ? MLModelState.PARTIALLY_DEPLOYED : MLModelState.DEPLOYED;
}
Map<String, Object> updateFields = new HashMap<>();
updateFields.put(MLModel.MODEL_STATE_FIELD, modelState);
Expand Down
5 changes: 5 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ public int errorNodesCount() {
public boolean allNodeFailed() {
return workerNodeSize != null && errors.size() == workerNodeSize;
}

public void updateWorkerNode(Set<String> nodesRemovedFromCluster) {
this.workerNodes.removeAll(nodesRemovedFromCluster);
this.workerNodeSize = this.workerNodeSize - nodesRemovedFromCluster.size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE;
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;
import static org.opensearch.ml.utils.TestHelper.setupTestClusterState;

import java.util.Arrays;
import java.util.HashSet;
Expand All @@ -43,6 +44,7 @@
import org.opensearch.Version;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
Expand Down Expand Up @@ -94,6 +96,8 @@ public class TransportForwardActionTests extends OpenSearchTestCase {

private TransportForwardAction forwardAction;

private ClusterState testState;

Settings settings = Settings
.builder()
.put(ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE.getKey(), true)
Expand Down Expand Up @@ -137,6 +141,9 @@ public void setup() {
)
);

testState = setupTestClusterState("test_node_id2");
when(clusterService.state()).thenReturn(testState);

node1 = new DiscoveryNode(nodeId1, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT);
node2 = new DiscoveryNode(nodeId2, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public void setup() throws IOException {
encryptor = spy(new EncryptorImpl(null));
syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor);

testState = setupTestClusterState();
testState = setupTestClusterState("node");
when(clusterService.state()).thenReturn(testState);

doAnswer(invocation -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public void setup() throws IOException {
.build();

clusterName = new ClusterName("test cluster");
testState = setupTestClusterState();
testState = setupTestClusterState("node");
when(clusterService.state()).thenReturn(testState);

doAnswer(invocation -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public void setup() throws IOException {
roleSet,
Version.CURRENT
);
testState = setupTestClusterState();
testState = setupTestClusterState("node");
when(clusterService.state()).thenReturn(testState);

clusterName = new ClusterName(clusterNameStr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public class RestMLUndeployModelActionTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
testState = setupTestClusterState();
testState = setupTestClusterState("node");
when(clusterService.state()).thenReturn(testState);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings);
Expand Down
4 changes: 2 additions & 2 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,11 @@ public static ClusterState state(int numDataNodes, String indexName, String mapp
return state(new ClusterName("test"), indexName, mapping, clusterManagerNode, clusterManagerNode, allNodes);
}

public static ClusterState setupTestClusterState() {
public static ClusterState setupTestClusterState(String nodeId) {
Set<DiscoveryNodeRole> roleSet = new HashSet<>();
roleSet.add(DiscoveryNodeRole.DATA_ROLE);
DiscoveryNode node = new DiscoveryNode(
"node",
nodeId,
new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()),
new HashMap<>(),
roleSet,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,8 @@ public class GenerativeQAProcessorConstants {
.boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final String FEATURE_NOT_ENABLED_ERROR_MSG = RAG_PIPELINE_FEATURE_ENABLED.getKey() + " is not enabled.";

public static final String RAG_NULL_GEN_QA_PARAMS_ERROR_MSG = "generative_qa_parameters not found."
+ " Please provide ext.generative_qa_parameters to proceed."
+ " For more info, refer: https://opensearch.org/docs/latest/search-plugins/conversational-search/#step-6-use-the-pipeline-for-rag";
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.opensearch.searchpipelines.questionanswering.generative;

import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG;

import java.time.Duration;
import java.time.Instant;
Expand Down Expand Up @@ -126,6 +127,9 @@ public void processResponseAsync(
}

GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
if (params == null) {
throw new IllegalArgumentException(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG);
}

Integer t = params.getTimeout();
if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG;

import java.time.Instant;
import java.util.Collections;
Expand Down Expand Up @@ -646,6 +647,77 @@ public void testProcessResponseNullValueInteractions() throws Exception {
}));
}

public void testProcessResponseIllegalArgumentForNullParams() throws Exception {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG);

Client client = mock(Client.class);
Map<String, Object> config = new HashMap<>();
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model");
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text"));

GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(
client,
alwaysOn
).create(null, "tag", "desc", true, config, null);

ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class);
List<Interaction> chatHistory = List
.of(
new Interaction(
"0",
Instant.now(),
"1",
"question",
"",
"answer",
"foo",
Collections.singletonMap("meta data", "some meta")
)
);
doAnswer(invocation -> {
((ActionListener<List<Interaction>>) invocation.getArguments()[2]).onResponse(chatHistory);
return null;
}).when(memoryClient).getInteractions(any(), anyInt(), any());
processor.setMemoryClient(memoryClient);

SearchRequest request = new SearchRequest();
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(null);
request.source(sourceBuilder);
sourceBuilder.ext(List.of(extBuilder));

int numHits = 10;
SearchHit[] hitsArray = new SearchHit[numHits];
for (int i = 0; i < numHits; i++) {
XContentBuilder sourceContent = JsonXContent
.contentBuilder()
.startObject()
.field("_id", String.valueOf(i))
.field("text", "passage" + i)
.field("title", "This is the title for document " + i)
.endObject();
hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of());
hitsArray[i].sourceRef(BytesReference.bytes(sourceContent));
}

SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f);
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null);

Llm llm = mock(Llm.class);
processor.setLlm(llm);

processor
.processResponseAsync(
request,
response,
null,
ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> {})
);
}

public void testProcessResponseIllegalArgument() throws Exception {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("llm_model cannot be null.");
Expand Down

0 comments on commit 2988c87

Please sign in to comment.