diff --git a/server/src/main/java/org/opensearch/rest/RequestLimitSettings.java b/server/src/main/java/org/opensearch/rest/RequestLimitSettings.java index 727113def43a6..8cc4199d302b9 100644 --- a/server/src/main/java/org/opensearch/rest/RequestLimitSettings.java +++ b/server/src/main/java/org/opensearch/rest/RequestLimitSettings.java @@ -94,17 +94,28 @@ public boolean isCircuitLimitBreached(final ClusterState clusterState, final Blo switch (actionToCheck) { case CAT_INDICES: if (catIndicesLimit <= 0) return false; - int indicesCount = getTotalIndices(clusterState); + int indicesCount = chainWalk(() -> clusterState.getMetadata().getIndices().size(), 0); if (indicesCount > catIndicesLimit) return true; break; case CAT_SHARDS: if (catShardsLimit <= 0) return false; - int totalShards = getTotalShards(clusterState); - if (totalShards > catShardsLimit) return true; + final RoutingTable routingTable = clusterState.getRoutingTable(); + final Map indexRoutingTableMap = routingTable.getIndicesRouting(); + int totalShards = 0; + for (final Map.Entry entry : indexRoutingTableMap.entrySet()) { + for (final Map.Entry indexShardRoutingTableEntry : entry.getValue() + .getShards() + .entrySet()) { + totalShards += indexShardRoutingTableEntry.getValue().getShards().size(); + // Fail fast if catShardsLimit value is breached and avoid unnecessary computation. + if (totalShards > catShardsLimit) return true; + } + } break; case CAT_SEGMENTS: if (catSegmentsLimit <= 0) return false; - if (getTotalIndices(clusterState) > catSegmentsLimit) return true; + int indicesCountForCatSegment = chainWalk(() -> clusterState.getRoutingTable().getIndicesRouting().size(), 0); + if (indicesCountForCatSegment > catSegmentsLimit) return true; break; } return false; @@ -122,22 +133,6 @@ private void setCatSegmentsLimitSetting(final int catSegmentsLimit) { this.catSegmentsLimit = catSegmentsLimit; } - private static int getTotalIndices(final ClusterState clusterState) { - return chainWalk(() -> clusterState.getMetadata().getIndices().size(), 0); - } - - private static int getTotalShards(final ClusterState clusterState) { - final RoutingTable routingTable = clusterState.getRoutingTable(); - final Map indexRoutingTableMap = routingTable.getIndicesRouting(); - int totalShards = 0; - for (final Map.Entry entry : indexRoutingTableMap.entrySet()) { - for (final Map.Entry indexShardRoutingTableEntry : entry.getValue().getShards().entrySet()) { - totalShards += indexShardRoutingTableEntry.getValue().getShards().size(); - } - } - return totalShards; - } - // TODO: Evaluate if we can move this to common util. private static T chainWalk(Supplier supplier, T defaultValue) { try {