diff --git a/server/src/main/java/org/opensearch/cluster/routing/RoutingNode.java b/server/src/main/java/org/opensearch/cluster/routing/RoutingNode.java index 1b1a9394ff306..bf8288885b4a4 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/RoutingNode.java +++ b/server/src/main/java/org/opensearch/cluster/routing/RoutingNode.java @@ -49,6 +49,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; import java.util.stream.StreamSupport; /** @@ -84,7 +85,7 @@ public ShardRouting get(ShardId shardId) { return this.shardTuple.v2().get(shardId); } - public ShardRouting add(ShardRouting shardRouting) { + public ShardRouting put(ShardRouting shardRouting) { return put(shardRouting.shardId(), shardRouting); } @@ -114,22 +115,10 @@ public ShardRouting remove(ShardId shardId) { @Override public Iterator iterator() { - final Iterator primaryIterator = Collections.unmodifiableCollection(this.shardTuple.v1().values()).iterator(); - final Iterator replicaIterator = Collections.unmodifiableCollection(this.shardTuple.v2().values()).iterator(); - return new Iterator() { - @Override - public boolean hasNext() { - return primaryIterator.hasNext() || replicaIterator.hasNext(); - } - - @Override - public ShardRouting next() { - if (primaryIterator.hasNext()) { - return primaryIterator.next(); - } - return replicaIterator.next(); - } - }; + return Stream.concat( + Collections.unmodifiableCollection(this.shardTuple.v1().values()).stream(), + Collections.unmodifiableCollection(this.shardTuple.v2().values()).stream() + ).iterator(); } } @@ -217,7 +206,7 @@ public int size() { */ void add(ShardRouting shard) { assert invariant(); - if (shards.add(shard) != null) { + if (shards.put(shard) != null) { throw new IllegalStateException( "Trying to add a shard " + shard.shardId() diff --git a/server/src/test/java/org/opensearch/cluster/routing/MovePrimaryFirstTests.java b/server/src/test/java/org/opensearch/cluster/routing/MovePrimaryFirstTests.java index c484d0eb759fc..fba6f1d48930b 100644 --- a/server/src/test/java/org/opensearch/cluster/routing/MovePrimaryFirstTests.java +++ b/server/src/test/java/org/opensearch/cluster/routing/MovePrimaryFirstTests.java @@ -12,11 +12,15 @@ import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; import org.opensearch.cluster.ClusterStateListener; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.test.InternalTestCluster; import org.opensearch.test.OpenSearchIntegTestCase; +import java.util.ArrayList; import java.util.Iterator; +import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.stream.Stream; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; @@ -83,19 +87,25 @@ public void testClusterGreenAfterPartialRelocation() throws InterruptedException final ClusterStateListener listener = event -> { if (event.routingTableChanged()) { final RoutingNodes routingNodes = event.state().getRoutingNodes(); - int startedz2n1 = 0; - int startedz2n2 = 0; + int startedCount = 0; + List initz2n1 = new ArrayList<>(), initz2n2 = new ArrayList<>(); for (Iterator it = routingNodes.iterator(); it.hasNext();) { RoutingNode routingNode = it.next(); final String nodeName = routingNode.node().getName(); if (nodeName.equals(z2n1)) { - startedz2n1 = routingNode.numberOfShardsWithState(ShardRoutingState.STARTED); + startedCount += routingNode.numberOfShardsWithState(ShardRoutingState.STARTED); + initz2n1 = routingNode.shardsWithState(ShardRoutingState.INITIALIZING); } else if (nodeName.equals(z2n2)) { - startedz2n2 = routingNode.numberOfShardsWithState(ShardRoutingState.STARTED); + startedCount += routingNode.numberOfShardsWithState(ShardRoutingState.STARTED); + initz2n2 = routingNode.shardsWithState(ShardRoutingState.INITIALIZING); } } - if (startedz2n1 >= primaryShardCount / 2 && startedz2n2 >= primaryShardCount / 2) { - primaryMoveLatch.countDown(); + if (!Stream.concat(initz2n1.stream(), initz2n2.stream()).anyMatch(s -> s.primary())) { + // All primaries are relocated before 60% of total shards are started on new nodes + final int totalShardCount = primaryShardCount * 2; + if (primaryShardCount <= startedCount && startedCount <= 3 * totalShardCount / 5) { + primaryMoveLatch.countDown(); + } } } }; @@ -112,6 +122,6 @@ public void testClusterGreenAfterPartialRelocation() throws InterruptedException internalCluster().stopRandomNode(InternalTestCluster.nameFilter(z1n1)); internalCluster().stopRandomNode(InternalTestCluster.nameFilter(z1n2)); } catch (Exception e) {} - ensureGreen(); + ensureGreen(TimeValue.timeValueSeconds(60)); } }