Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify shardsWithState #91991

Merged
merged 8 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toCollection;

/**
* A {@link RoutingNode} represents a cluster node associated with a single {@link DiscoveryNode} including all shards
Expand Down Expand Up @@ -211,29 +214,11 @@ void remove(ShardRouting shard) {

/**
* Determine the number of shards with a specific state
* @param states set of states which should be counted
* @param state which should be counted
* @return number of shards
*/
public int numberOfShardsWithState(ShardRoutingState... states) {
if (states.length == 1) {
if (states[0] == ShardRoutingState.INITIALIZING) {
return initializingShards.size();
} else if (states[0] == ShardRoutingState.RELOCATING) {
return relocatingShards.size();
} else if (states[0] == ShardRoutingState.STARTED) {
return startedShards.size();
}
}

int count = 0;
for (ShardRouting shardEntry : this) {
for (ShardRoutingState state : states) {
if (shardEntry.state() == state) {
count++;
}
}
}
return count;
public int numberOfShardsWithState(ShardRoutingState state) {
return internalGetShardsWithState(state).size();
}

/**
Expand All @@ -242,20 +227,7 @@ public int numberOfShardsWithState(ShardRoutingState... states) {
* @return List of shards
*/
public List<ShardRouting> shardsWithState(ShardRoutingState state) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAICT this method is only used in two places in prod code, both of which could reasonably accept a Stream<ShardRouting> instead and avoid the need to copy into a new list.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has plenty of unit test usages that needs to be updated as well. Do you mind if I do it in a followup pr (possibly with moving some of this methods to a test code)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good

if (state == ShardRoutingState.INITIALIZING) {
return new ArrayList<>(initializingShards);
} else if (state == ShardRoutingState.RELOCATING) {
return new ArrayList<>(relocatingShards);
} else if (state == ShardRoutingState.STARTED) {
return new ArrayList<>(startedShards);
}
List<ShardRouting> shards = new ArrayList<>();
for (ShardRouting shardEntry : this) {
if (shardEntry.state() == state) {
shards.add(shardEntry);
}
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch was effectively unreachable.

return shards;
return new ArrayList<>(internalGetShardsWithState(state));
}

private static final ShardRouting[] EMPTY_SHARD_ROUTING_ARRAY = new ShardRouting[0];
Expand All @@ -279,49 +251,28 @@ public ShardRouting[] started() {
* @return a list of shards
*/
public List<ShardRouting> shardsWithState(String index, ShardRoutingState... states) {
List<ShardRouting> shards = new ArrayList<>();

if (states.length == 1) {
if (states[0] == ShardRoutingState.INITIALIZING) {
for (ShardRouting shardEntry : initializingShards) {
if (shardEntry.getIndexName().equals(index) == false) {
continue;
}
shards.add(shardEntry);
}
return shards;
} else if (states[0] == ShardRoutingState.RELOCATING) {
for (ShardRouting shardEntry : relocatingShards) {
if (shardEntry.getIndexName().equals(index) == false) {
continue;
}
shards.add(shardEntry);
}
return shards;
} else if (states[0] == ShardRoutingState.STARTED) {
for (ShardRouting shardEntry : startedShards) {
if (shardEntry.getIndexName().equals(index) == false) {
continue;
}
shards.add(shardEntry);
}
return shards;
}
}
return Stream.of(states).flatMap(state -> shardsWithState(index, state).stream()).collect(toCollection(ArrayList::new));
}

for (ShardRouting shardEntry : this) {
if (shardEntry.getIndexName().equals(index) == false) {
continue;
}
for (ShardRoutingState state : states) {
if (shardEntry.state() == state) {
shards.add(shardEntry);
}
public List<ShardRouting> shardsWithState(String index, ShardRoutingState state) {
var shards = new ArrayList<ShardRouting>();
for (ShardRouting shardEntry : internalGetShardsWithState(state)) {
if (shardEntry.getIndexName().equals(index)) {
shards.add(shardEntry);
}
}
return shards;
}

private LinkedHashSet<ShardRouting> internalGetShardsWithState(ShardRoutingState state) {
return switch (state) {
case UNASSIGNED -> throw new IllegalArgumentException("Unassigned shards are not linked to a routing node");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was previously returning empty list or 0 but I guess it is better to be explicit that routing node can not be used to obtain unassigned shards.

case INITIALIZING -> initializingShards;
case STARTED -> startedShards;
case RELOCATING -> relocatingShards;
};
}

/**
* The number of shards on this node that will not be eventually relocated.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ public void testRemove() {
}

public void testNumberOfShardsWithState() {
assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.INITIALIZING, ShardRoutingState.STARTED), equalTo(2));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to test since this method is removed

assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.STARTED), equalTo(1));
assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.RELOCATING), equalTo(1));
assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.INITIALIZING), equalTo(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import static org.elasticsearch.cluster.routing.ShardRoutingState.UNASSIGNED;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;

Expand Down Expand Up @@ -507,10 +506,14 @@ public void testRebalanceFailure() {
RoutingNodes routingNodes = clusterState.getRoutingNodes();

assertThat(clusterState.routingTable().index("test").size(), equalTo(2));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED), lessThan(3));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and following lessThan(3) assertions are removed as they are covered by assertion above

assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2));
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED), lessThan(3));
assertThat(
routingNodes.node("node1").numberOfShardsWithState(STARTED) + routingNodes.node("node1").numberOfShardsWithState(RELOCATING),
equalTo(2)
);
assertThat(
routingNodes.node("node2").numberOfShardsWithState(STARTED) + routingNodes.node("node2").numberOfShardsWithState(RELOCATING),
equalTo(2)
);
assertThat(routingNodes.node("node3").numberOfShardsWithState(INITIALIZING), equalTo(1));

logger.info("Fail the shards on node 3");
Expand All @@ -521,10 +524,14 @@ public void testRebalanceFailure() {
routingNodes = clusterState.getRoutingNodes();

assertThat(clusterState.routingTable().index("test").size(), equalTo(2));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED), lessThan(3));
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2));
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED), lessThan(3));
assertThat(
routingNodes.node("node1").numberOfShardsWithState(STARTED) + routingNodes.node("node1").numberOfShardsWithState(RELOCATING),
equalTo(2)
);
assertThat(
routingNodes.node("node2").numberOfShardsWithState(STARTED) + routingNodes.node("node2").numberOfShardsWithState(RELOCATING),
equalTo(2)
);

if (strategy.isBalancedShardsAllocator()) {
assertThat(routingNodes.node("node3").numberOfShardsWithState(INITIALIZING), equalTo(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,16 @@ public void testSingleIndexFirstStartPrimaryThenBackups() {
routingNodes = clusterState.getRoutingNodes();

assertThat(clusterState.routingTable().index("test").size(), equalTo(10));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED, RELOCATING), equalTo(10));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED), lessThan(10));
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED, RELOCATING), equalTo(10));
assertThat(
routingNodes.node("node1").numberOfShardsWithState(STARTED) + routingNodes.node("node1").numberOfShardsWithState(RELOCATING),
equalTo(10)
);
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED), lessThan(10));
assertThat(
routingNodes.node("node2").numberOfShardsWithState(STARTED) + routingNodes.node("node2").numberOfShardsWithState(RELOCATING),
equalTo(10)
);
assertThat(routingNodes.node("node3").numberOfShardsWithState(INITIALIZING), equalTo(6));

logger.info("Start the shards on node 3");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,29 @@ private RoutingNodesHelper() {}

public static List<ShardRouting> shardsWithState(RoutingNodes routingNodes, ShardRoutingState state) {
List<ShardRouting> shards = new ArrayList<>();
for (RoutingNode routingNode : routingNodes) {
shards.addAll(routingNode.shardsWithState(state));
}
if (state == ShardRoutingState.UNASSIGNED) {
routingNodes.unassigned().forEach(shards::add);
} else {
for (RoutingNode routingNode : routingNodes) {
shards.addAll(routingNode.shardsWithState(state));
}
}
return shards;
}

public static List<ShardRouting> shardsWithState(RoutingNodes routingNodes, String index, ShardRoutingState... state) {
public static List<ShardRouting> shardsWithState(RoutingNodes routingNodes, String index, ShardRoutingState... states) {
List<ShardRouting> shards = new ArrayList<>();
for (RoutingNode routingNode : routingNodes) {
shards.addAll(routingNode.shardsWithState(index, state));
}
for (ShardRoutingState s : state) {
if (s == ShardRoutingState.UNASSIGNED) {
for (ShardRoutingState state : states) {
if (state == ShardRoutingState.UNASSIGNED) {
for (ShardRouting unassignedShard : routingNodes.unassigned()) {
if (unassignedShard.index().getName().equals(index)) {
shards.add(unassignedShard);
}
}
break;
} else {
for (RoutingNode routingNode : routingNodes) {
shards.addAll(routingNode.shardsWithState(index, state));
}
}
}
return shards;
Expand All @@ -64,7 +65,6 @@ public static RoutingNode routingNode(String nodeId, DiscoveryNode node, ShardRo
for (ShardRouting shardRouting : shards) {
routingNode.add(shardRouting);
}

return routingNode;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ public void testClusterChangedWatchAliasChanged() throws Exception {
boolean emptyShards = randomBoolean();

if (emptyShards) {
when(routingNode.shardsWithState(eq(newActiveWatchIndex), any())).thenReturn(Collections.emptyList());
when(routingNode.shardsWithState(eq(newActiveWatchIndex), any(ShardRoutingState[].class))).thenReturn(Collections.emptyList());
} else {
Index index = new Index(newActiveWatchIndex, "uuid");
ShardId shardId = new ShardId(index, 0);
Expand Down