diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java b/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java index 73084df8b7..5476e3f520 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java @@ -68,7 +68,7 @@ public String[] getEligibleNodeIds(FunctionName functionName) { public DiscoveryNode[] getEligibleNodes(FunctionName functionName) { ClusterState state = this.clusterService.state(); - final List eligibleNodes = new ArrayList<>(); + final Set eligibleNodes = new HashSet<>(); for (DiscoveryNode node : state.nodes()) { if (excludedNodeNames != null && excludedNodeNames.contains(node.getName())) { continue; @@ -88,7 +88,7 @@ public DiscoveryNode[] getEligibleNodes(FunctionName functionName) { return eligibleNodes.toArray(new DiscoveryNode[0]); } - private void getEligibleNodes(Set allowedNodeRoles, List eligibleNodes, DiscoveryNode node) { + private void getEligibleNodes(Set allowedNodeRoles, Set eligibleNodes, DiscoveryNode node) { if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) { eligibleNodes.add(node); } @@ -110,21 +110,21 @@ public String[] filterEligibleNodes(FunctionName functionName, String[] nodeIds) continue; } if (functionName == FunctionName.REMOTE) {// remote model - getEligibleNodes(remoteModelEligibleNodeRoles, eligibleNodes, node); + getEligibleNodeIds(remoteModelEligibleNodeRoles, eligibleNodes, node); } else { // local model if (onlyRunOnMLNode) { if (MLNodeUtils.isMLNode(node)) { eligibleNodes.add(node.getId()); } } else { - getEligibleNodes(localModelEligibleNodeRoles, eligibleNodes, node); + getEligibleNodeIds(localModelEligibleNodeRoles, eligibleNodes, node); } } } return eligibleNodes.toArray(new String[0]); } - private void getEligibleNodes(Set allowedNodeRoles, Set eligibleNodes, DiscoveryNode node) { + private void getEligibleNodeIds(Set allowedNodeRoles, Set eligibleNodes, DiscoveryNode node) { if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) { eligibleNodes.add(node.getId()); } diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java index 602db3220c..84df4ccb3e 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java @@ -13,6 +13,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES; +import static org.opensearch.ml.utils.TestHelper.ALL_ROLES; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import static org.opensearch.ml.utils.TestHelper.clusterSetting; @@ -52,6 +53,8 @@ public class DiscoveryNodeHelperTests extends OpenSearchTestCase { private final String mlNode1Name = "mlNodeName1"; private final String mlNode2Id = "mlNode2"; private final String mlNode2Name = "mlNodeName2"; + private final String allRoleNodeId = "allRoleNode"; + private final String allRoleNodeName = "allRoleNodeName"; private final String clusterName = "multi-node-cluster"; @Mock @@ -65,6 +68,7 @@ public class DiscoveryNodeHelperTests extends OpenSearchTestCase { private DiscoveryNode warmDataNode1; private DiscoveryNode mlNode1; private DiscoveryNode mlNode2; + private DiscoveryNode allRoleNode; private ClusterState clusterState; private String nonExistingNodeName; @@ -122,6 +126,14 @@ public void setup() throws IOException { ImmutableSet.of(ML_ROLE), Version.CURRENT ); + allRoleNode = new DiscoveryNode( + allRoleNodeName, + allRoleNodeId, + buildNewFakeTransportAddress(), + emptyMap(), + ALL_ROLES, + Version.CURRENT + ); DiscoveryNodes nodes = DiscoveryNodes .builder() @@ -131,6 +143,7 @@ public void setup() throws IOException { .add(warmDataNode1) .add(mlNode1) .add(mlNode2) + .add(allRoleNode) .build(); clusterState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, Map.of(), 0, false); @@ -158,23 +171,35 @@ private void mockSettings(boolean onlyRunOnMLNode, String excludedNodeName) { public void testGetEligibleNodes_MLNode_RemoteModel() { DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(FunctionName.REMOTE); - assertEquals(4, eligibleNodes.length); + assertEquals(5, eligibleNodes.length); Set nodeIds = new HashSet<>(); nodeIds.addAll(Arrays.asList(eligibleNodes).stream().map(n -> n.getId()).collect(Collectors.toList())); assertTrue(nodeIds.contains(mlNode1.getId())); assertTrue(nodeIds.contains(mlNode2.getId())); assertTrue(nodeIds.contains(dataNode1.getId())); assertTrue(nodeIds.contains(dataNode2.getId())); + assertTrue(nodeIds.contains(allRoleNode.getId())); assertFalse(nodeIds.contains(warmDataNode1.getId())); } public void testGetEligibleNodes_MLNode_LocalModel() { DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(FunctionName.TEXT_EMBEDDING); - assertEquals(2, eligibleNodes.length); + assertEquals(3, eligibleNodes.length); Set nodeIds = new HashSet<>(); nodeIds.addAll(Arrays.asList(eligibleNodes).stream().map(n -> n.getId()).collect(Collectors.toList())); assertTrue(nodeIds.contains(mlNode1.getId())); assertTrue(nodeIds.contains(mlNode2.getId())); + assertTrue(nodeIds.contains(allRoleNode.getId())); + } + + public void testGetEligibleNodes_MLNode_DataModel() { + DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(FunctionName.TEXT_EMBEDDING); + assertEquals(3, eligibleNodes.length); + Set nodeIds = new HashSet<>(); + nodeIds.addAll(Arrays.asList(eligibleNodes).stream().map(n -> n.getId()).collect(Collectors.toList())); + assertTrue(nodeIds.contains(mlNode1.getId())); + assertTrue(nodeIds.contains(mlNode2.getId())); + assertTrue(nodeIds.contains(allRoleNode.getId())); } public void testGetEligibleNodes_DataNode() { @@ -186,17 +211,25 @@ public void testGetEligibleNodes_DataNode() { DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(FunctionName.REMOTE); assertEquals(2, eligibleNodes.length); - assertEquals(dataNode1.getName(), eligibleNodes[0].getName()); - assertEquals(dataNode2.getName(), eligibleNodes[1].getName()); + Set nodeNames = new HashSet<>(); + nodeNames.add("dataNodeName1"); + nodeNames.add("dataNodeName2"); + assertTrue(nodeNames.contains(eligibleNodes[0].getName())); + assertTrue(nodeNames.contains(eligibleNodes[1].getName())); } public void testGetEligibleNodes_MLNode_Excluded() { mockSettings(false, mlNode1.getName() + "," + mlNode2.getName()); DiscoveryNodeHelper discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(FunctionName.TEXT_EMBEDDING); - assertEquals(2, eligibleNodes.length); - assertEquals(dataNode1.getName(), eligibleNodes[0].getName()); - assertEquals(dataNode2.getName(), eligibleNodes[1].getName()); + assertEquals(3, eligibleNodes.length); + Set nodeNames = new HashSet<>(); + nodeNames.add("dataNodeName1"); + nodeNames.add("dataNodeName2"); + nodeNames.add("allRoleNodeName"); + assertTrue(nodeNames.contains(eligibleNodes[0].getName())); + assertTrue(nodeNames.contains(eligibleNodes[1].getName())); + assertTrue(nodeNames.contains(eligibleNodes[2].getName())); } public void testFilterEligibleNodes_Null() { @@ -241,7 +274,7 @@ public void testFilterEligibleNodes_BothMLAndDataNodes() { public void testGetAllNodeIds() { String[] allNodeIds = discoveryNodeHelper.getAllNodeIds(); - assertEquals(6, allNodeIds.length); + assertEquals(7, allNodeIds.length); } public void testGetNodes() { diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 7eaf2e8db9..ecd97c233a 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -10,6 +10,9 @@ import static org.junit.Assert.assertNotNull; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; import static org.opensearch.cluster.node.DiscoveryNodeRole.DATA_ROLE; +import static org.opensearch.cluster.node.DiscoveryNodeRole.INGEST_ROLE; +import static org.opensearch.cluster.node.DiscoveryNodeRole.REMOTE_CLUSTER_CLIENT_ROLE; +import static org.opensearch.cluster.node.DiscoveryNodeRole.SEARCH_ROLE; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; @@ -21,12 +24,15 @@ import java.io.InputStreamReader; import java.net.InetAddress; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -92,6 +98,11 @@ public Setting legacySetting() { } }; + public static SortedSet ALL_ROLES = Collections + .unmodifiableSortedSet( + new TreeSet<>(Arrays.asList(DATA_ROLE, INGEST_ROLE, CLUSTER_MANAGER_ROLE, REMOTE_CLUSTER_CLIENT_ROLE, SEARCH_ROLE, ML_ROLE)) + ); + public static XContentParser parser(String xc) throws IOException { return parser(xc, true); }