From dda0cb72650ce515c08e5b1d020893da93113503 Mon Sep 17 00:00:00 2001 From: Anshu Agarwal Date: Mon, 19 Sep 2022 14:13:18 +0530 Subject: [PATCH] =?UTF-8?q?Weighted=20round-robin=20scheduling=20policy=20?= =?UTF-8?q?for=20shard=20coordination=20traffic=E2=80=A6=20(#4241)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Weighted round-robin scheduling policy for shard coordination traffic routing Signed-off-by: Anshu Agarwal --- CHANGELOG.md | 1 + .../org/opensearch/cluster/ClusterModule.java | 9 + .../opensearch/cluster/metadata/Metadata.java | 8 + .../metadata/WeightedRoutingMetadata.java | 167 +++++++++ .../routing/IndexShardRoutingTable.java | 134 ++++++++ .../cluster/routing/OperationRouting.java | 46 ++- .../cluster/routing/WeightedRoundRobin.java | 106 ++++++ .../cluster/routing/WeightedRouting.java | 75 +++++ .../common/settings/ClusterSettings.java | 1 + .../WeightedRoutingMetadataTests.java | 36 ++ .../routing/OperationRoutingTests.java | 318 ++++++++++++++++++ .../routing/WeightedRoundRobinTests.java | 151 +++++++++ .../structure/RoutingIteratorTests.java | 210 ++++++++++++ 13 files changed, 1254 insertions(+), 8 deletions(-) create mode 100644 server/src/main/java/org/opensearch/cluster/metadata/WeightedRoutingMetadata.java create mode 100644 server/src/main/java/org/opensearch/cluster/routing/WeightedRoundRobin.java create mode 100644 server/src/main/java/org/opensearch/cluster/routing/WeightedRouting.java create mode 100644 server/src/test/java/org/opensearch/cluster/metadata/WeightedRoutingMetadataTests.java create mode 100644 server/src/test/java/org/opensearch/cluster/routing/WeightedRoundRobinTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index c2e4fc2fe1aba..3c610c89eeaf3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Dependency updates (httpcore, mockito, slf4j, httpasyncclient, commons-codec) ([#4308](https://github.com/opensearch-project/OpenSearch/pull/4308)) - Use RemoteSegmentStoreDirectory instead of RemoteDirectory ([#4240](https://github.com/opensearch-project/OpenSearch/pull/4240)) - Plugin ZIP publication groupId value is configurable ([#4156](https://github.com/opensearch-project/OpenSearch/pull/4156)) +- Weighted round-robin scheduling policy for shard coordination traffic ([#4241](https://github.com/opensearch-project/OpenSearch/pull/4241)) - Add index specific setting for remote repository ([#4253](https://github.com/opensearch-project/OpenSearch/pull/4253)) - [Segment Replication] Update replicas to commit SegmentInfos instead of relying on SIS files from primary shards. ([#4402](https://github.com/opensearch-project/OpenSearch/pull/4402)) - [CCR] Add getHistoryOperationsFromTranslog method to fetch the history snapshot from translogs ([#3948](https://github.com/opensearch-project/OpenSearch/pull/3948)) diff --git a/server/src/main/java/org/opensearch/cluster/ClusterModule.java b/server/src/main/java/org/opensearch/cluster/ClusterModule.java index f8ba520e465e2..46552bb5d6a03 100644 --- a/server/src/main/java/org/opensearch/cluster/ClusterModule.java +++ b/server/src/main/java/org/opensearch/cluster/ClusterModule.java @@ -48,6 +48,7 @@ import org.opensearch.cluster.metadata.MetadataMappingService; import org.opensearch.cluster.metadata.MetadataUpdateSettingsService; import org.opensearch.cluster.metadata.RepositoriesMetadata; +import org.opensearch.cluster.metadata.WeightedRoutingMetadata; import org.opensearch.cluster.routing.DelayedAllocationService; import org.opensearch.cluster.routing.allocation.AllocationService; import org.opensearch.cluster.routing.allocation.ExistingShardsAllocator; @@ -191,6 +192,7 @@ public static List getNamedWriteables() { ComposableIndexTemplateMetadata::readDiffFrom ); registerMetadataCustom(entries, DataStreamMetadata.TYPE, DataStreamMetadata::new, DataStreamMetadata::readDiffFrom); + registerMetadataCustom(entries, WeightedRoutingMetadata.TYPE, WeightedRoutingMetadata::new, WeightedRoutingMetadata::readDiffFrom); // Task Status (not Diffable) entries.add(new Entry(Task.Status.class, PersistentTasksNodeService.Status.NAME, PersistentTasksNodeService.Status::new)); return entries; @@ -274,6 +276,13 @@ public static List getNamedXWriteables() { DataStreamMetadata::fromXContent ) ); + entries.add( + new NamedXContentRegistry.Entry( + Metadata.Custom.class, + new ParseField(WeightedRoutingMetadata.TYPE), + WeightedRoutingMetadata::fromXContent + ) + ); return entries; } diff --git a/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java b/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java index 5f7e98e9e1199..086865d2170c3 100644 --- a/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java +++ b/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java @@ -810,6 +810,14 @@ public IndexGraveyard indexGraveyard() { return custom(IndexGraveyard.TYPE); } + /** + * * + * @return The weighted routing metadata for search requests + */ + public WeightedRoutingMetadata weightedRoutingMetadata() { + return custom(WeightedRoutingMetadata.TYPE); + } + public T custom(String type) { return (T) customs.get(type); } diff --git a/server/src/main/java/org/opensearch/cluster/metadata/WeightedRoutingMetadata.java b/server/src/main/java/org/opensearch/cluster/metadata/WeightedRoutingMetadata.java new file mode 100644 index 0000000000000..27beb21f28f7c --- /dev/null +++ b/server/src/main/java/org/opensearch/cluster/metadata/WeightedRoutingMetadata.java @@ -0,0 +1,167 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.cluster.metadata; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchParseException; +import org.opensearch.Version; +import org.opensearch.cluster.AbstractNamedDiffable; +import org.opensearch.cluster.NamedDiff; +import org.opensearch.cluster.routing.WeightedRouting; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; + +/** + * Contains metadata for weighted routing + * + * @opensearch.internal + */ +public class WeightedRoutingMetadata extends AbstractNamedDiffable implements Metadata.Custom { + private static final Logger logger = LogManager.getLogger(WeightedRoutingMetadata.class); + public static final String TYPE = "weighted_shard_routing"; + public static final String AWARENESS = "awareness"; + private WeightedRouting weightedRouting; + + public WeightedRouting getWeightedRouting() { + return weightedRouting; + } + + public WeightedRoutingMetadata setWeightedRouting(WeightedRouting weightedRouting) { + this.weightedRouting = weightedRouting; + return this; + } + + public WeightedRoutingMetadata(StreamInput in) throws IOException { + if (in.available() != 0) { + this.weightedRouting = new WeightedRouting(in); + } + } + + public WeightedRoutingMetadata(WeightedRouting weightedRouting) { + this.weightedRouting = weightedRouting; + } + + @Override + public EnumSet context() { + return Metadata.API_AND_GATEWAY; + } + + @Override + public String getWriteableName() { + return TYPE; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.V_2_4_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (weightedRouting != null) { + weightedRouting.writeTo(out); + } + } + + public static NamedDiff readDiffFrom(StreamInput in) throws IOException { + return readDiffFrom(Metadata.Custom.class, TYPE, in); + } + + public static WeightedRoutingMetadata fromXContent(XContentParser parser) throws IOException { + String attrKey = null; + Double attrValue; + String attributeName = null; + Map weights = new HashMap<>(); + WeightedRouting weightedRouting = null; + XContentParser.Token token; + // move to the first alias + parser.nextToken(); + String awarenessField = null; + + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + awarenessField = parser.currentName(); + if (parser.nextToken() != XContentParser.Token.START_OBJECT) { + throw new OpenSearchParseException( + "failed to parse weighted routing metadata [{}], expected " + "object", + awarenessField + ); + } + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + attributeName = parser.currentName(); + if (parser.nextToken() != XContentParser.Token.START_OBJECT) { + throw new OpenSearchParseException( + "failed to parse weighted routing metadata [{}], expected" + " object", + attributeName + ); + } + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + attrKey = parser.currentName(); + } else if (token == XContentParser.Token.VALUE_NUMBER) { + attrValue = Double.parseDouble(parser.text()); + weights.put(attrKey, attrValue); + } else { + throw new OpenSearchParseException( + "failed to parse weighted routing metadata attribute " + "[{}], unknown type", + attributeName + ); + } + } + } + } + } + weightedRouting = new WeightedRouting(attributeName, weights); + return new WeightedRoutingMetadata(weightedRouting); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WeightedRoutingMetadata that = (WeightedRoutingMetadata) o; + return weightedRouting.equals(that.weightedRouting); + } + + @Override + public int hashCode() { + return weightedRouting.hashCode(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + toXContent(weightedRouting, builder); + return builder; + } + + public static void toXContent(WeightedRouting weightedRouting, XContentBuilder builder) throws IOException { + builder.startObject(AWARENESS); + builder.startObject(weightedRouting.attributeName()); + for (Map.Entry entry : weightedRouting.weights().entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + builder.endObject(); + } + + @Override + public String toString() { + return Strings.toString(this); + } +} diff --git a/server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java b/server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java index d4597f47d9a6c..9026e7068e9fe 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java +++ b/server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java @@ -85,6 +85,9 @@ public class IndexShardRoutingTable implements Iterable { private volatile Map activeShardsByAttributes = emptyMap(); private volatile Map initializingShardsByAttributes = emptyMap(); private final Object shardsByAttributeMutex = new Object(); + private final Object shardsByWeightMutex = new Object(); + private volatile Map> activeShardsByWeight = emptyMap(); + private volatile Map> initializingShardsByWeight = emptyMap(); /** * The initializing list, including ones that are initializing on a target node because of relocation. @@ -233,6 +236,10 @@ public List assignedShards() { return this.assignedShards; } + public Map> getActiveShardsByWeight() { + return activeShardsByWeight; + } + public ShardIterator shardsRandomIt() { return new PlainShardIterator(shardId, shuffler.shuffle(shards)); } @@ -292,6 +299,73 @@ public ShardIterator activeInitializingShardsRankedIt( return new PlainShardIterator(shardId, ordered); } + /** + * Returns an iterator over active and initializing shards, shards are ordered by weighted + * round-robin scheduling policy. + * + * @param weightedRouting entity + * @param nodes discovered nodes in the cluster + * @return an iterator over active and initializing shards, ordered by weighted round-robin + * scheduling policy. Making sure that initializing shards are the last to iterate through. + */ + public ShardIterator activeInitializingShardsWeightedIt(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) { + final int seed = shuffler.nextSeed(); + List ordered = new ArrayList<>(); + List orderedActiveShards = getActiveShardsByWeight(weightedRouting, nodes, defaultWeight); + ordered.addAll(shuffler.shuffle(orderedActiveShards, seed)); + if (!allInitializingShards.isEmpty()) { + List orderedInitializingShards = getInitializingShardsByWeight(weightedRouting, nodes, defaultWeight); + ordered.addAll(orderedInitializingShards); + } + return new PlainShardIterator(shardId, ordered); + } + + /** + * Returns a list containing shard routings ordered using weighted round-robin scheduling. + */ + private List shardsOrderedByWeight( + List shards, + WeightedRouting weightedRouting, + DiscoveryNodes nodes, + double defaultWeight + ) { + WeightedRoundRobin weightedRoundRobin = new WeightedRoundRobin<>( + calculateShardWeight(shards, weightedRouting, nodes, defaultWeight) + ); + List> shardsOrderedbyWeight = weightedRoundRobin.orderEntities(); + List orderedShardRouting = new ArrayList<>(activeShards.size()); + if (shardsOrderedbyWeight != null) { + for (WeightedRoundRobin.Entity shardRouting : shardsOrderedbyWeight) { + orderedShardRouting.add(shardRouting.getTarget()); + } + } + return orderedShardRouting; + } + + /** + * Returns a list containing shard routing and associated weight. This function iterates through all the shards and + * uses weighted routing to find weight for the corresponding shard. This is fed to weighted round-robin scheduling + * to order shards by weight. + */ + private List> calculateShardWeight( + List shards, + WeightedRouting weightedRouting, + DiscoveryNodes nodes, + double defaultWeight + ) { + List> shardsWithWeights = new ArrayList<>(); + for (ShardRouting shard : shards) { + DiscoveryNode node = nodes.get(shard.currentNodeId()); + if (node != null) { + String attVal = node.getAttributes().get(weightedRouting.attributeName()); + // If weight for a zone is not defined, considering it as 1 by default + Double weight = weightedRouting.weights().getOrDefault(attVal, defaultWeight); + shardsWithWeights.add(new WeightedRoundRobin.Entity<>(weight, shard)); + } + } + return shardsWithWeights; + } + private static Set getAllNodeIds(final List shards) { final Set nodeIds = new HashSet<>(); for (ShardRouting shard : shards) { @@ -698,6 +772,66 @@ public int shardsMatchingPredicateCount(Predicate predicate) { return count; } + /** + * Key for WeightedRouting Shard Iterator + * + * @opensearch.internal + */ + public static class WeightedRoutingKey { + private final WeightedRouting weightedRouting; + + public WeightedRoutingKey(WeightedRouting weightedRouting) { + this.weightedRouting = weightedRouting; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WeightedRoutingKey key = (WeightedRoutingKey) o; + if (!weightedRouting.equals(key.weightedRouting)) return false; + return true; + } + + @Override + public int hashCode() { + int result = weightedRouting.hashCode(); + return result; + } + } + + /** + * * + * Gets active shard routing from memory if available, else calculates and put it in memory. + */ + private List getActiveShardsByWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) { + WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting); + List shardRoutings = activeShardsByWeight.get(key); + if (shardRoutings == null) { + synchronized (shardsByWeightMutex) { + shardRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight); + activeShardsByWeight = new MapBuilder().put(key, shardRoutings).immutableMap(); + } + } + return shardRoutings; + } + + /** + * * + * Gets initializing shard routing from memory if available, else calculates and put it in memory. + */ + private List getInitializingShardsByWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) { + WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting); + List shardRoutings = initializingShardsByWeight.get(key); + if (shardRoutings == null) { + synchronized (shardsByWeightMutex) { + shardRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight); + initializingShardsByWeight = new MapBuilder().put(key, shardRoutings).immutableMap(); + } + } + return shardRoutings; + } + /** * Builder of an index shard routing table. * diff --git a/server/src/main/java/org/opensearch/cluster/routing/OperationRouting.java b/server/src/main/java/org/opensearch/cluster/routing/OperationRouting.java index 30f6408c19783..9026da667ccb0 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/OperationRouting.java +++ b/server/src/main/java/org/opensearch/cluster/routing/OperationRouting.java @@ -34,6 +34,7 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.WeightedRoutingMetadata; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.routing.allocation.decider.AwarenessAllocationDecider; import org.opensearch.common.Nullable; @@ -75,9 +76,17 @@ public class OperationRouting { Setting.Property.Dynamic, Setting.Property.NodeScope ); + public static final Setting WEIGHTED_ROUTING_DEFAULT_WEIGHT = Setting.doubleSetting( + "cluster.routing.weighted.default_weight", + 1.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); private volatile List awarenessAttributes; private volatile boolean useAdaptiveReplicaSelection; private volatile boolean ignoreAwarenessAttr; + private volatile double weightedRoutingDefaultWeight; public OperationRouting(Settings settings, ClusterSettings clusterSettings) { // whether to ignore awareness attributes when routing requests @@ -88,8 +97,10 @@ public OperationRouting(Settings settings, ClusterSettings clusterSettings) { this::setAwarenessAttributes ); this.useAdaptiveReplicaSelection = USE_ADAPTIVE_REPLICA_SELECTION_SETTING.get(settings); + this.weightedRoutingDefaultWeight = WEIGHTED_ROUTING_DEFAULT_WEIGHT.get(settings); clusterSettings.addSettingsUpdateConsumer(USE_ADAPTIVE_REPLICA_SELECTION_SETTING, this::setUseAdaptiveReplicaSelection); clusterSettings.addSettingsUpdateConsumer(IGNORE_AWARENESS_ATTRIBUTES_SETTING, this::setIgnoreAwarenessAttributes); + clusterSettings.addSettingsUpdateConsumer(WEIGHTED_ROUTING_DEFAULT_WEIGHT, this::setWeightedRoutingDefaultWeight); } void setUseAdaptiveReplicaSelection(boolean useAdaptiveReplicaSelection) { @@ -100,6 +111,10 @@ void setIgnoreAwarenessAttributes(boolean ignoreAwarenessAttributes) { this.ignoreAwarenessAttr = ignoreAwarenessAttributes; } + void setWeightedRoutingDefaultWeight(double weightedRoutingDefaultWeight) { + this.weightedRoutingDefaultWeight = weightedRoutingDefaultWeight; + } + public boolean isIgnoreAwarenessAttr() { return ignoreAwarenessAttr; } @@ -116,6 +131,10 @@ public boolean ignoreAwarenessAttributes() { return this.awarenessAttributes.isEmpty() || this.ignoreAwarenessAttr; } + public double getWeightedRoutingDefaultWeight() { + return this.weightedRoutingDefaultWeight; + } + public ShardIterator indexShards(ClusterState clusterState, String index, String id, @Nullable String routing) { return shards(clusterState, index, id, routing).shardsIt(); } @@ -133,7 +152,8 @@ public ShardIterator getShards( clusterState.nodes(), preference, null, - null + null, + clusterState.getMetadata().weightedRoutingMetadata() ); } @@ -145,7 +165,8 @@ public ShardIterator getShards(ClusterState clusterState, String index, int shar clusterState.nodes(), preference, null, - null + null, + clusterState.metadata().weightedRoutingMetadata() ); } @@ -175,7 +196,8 @@ public GroupShardsIterator searchShards( clusterState.nodes(), preference, collectorService, - nodeCounts + nodeCounts, + clusterState.metadata().weightedRoutingMetadata() ); if (iterator != null) { set.add(iterator); @@ -225,10 +247,11 @@ private ShardIterator preferenceActiveShardIterator( DiscoveryNodes nodes, @Nullable String preference, @Nullable ResponseCollectorService collectorService, - @Nullable Map nodeCounts + @Nullable Map nodeCounts, + @Nullable WeightedRoutingMetadata weightedRoutingMetadata ) { if (preference == null || preference.isEmpty()) { - return shardRoutings(indexShard, nodes, collectorService, nodeCounts); + return shardRoutings(indexShard, nodes, collectorService, nodeCounts, weightedRoutingMetadata); } if (preference.charAt(0) == '_') { Preference preferenceType = Preference.parse(preference); @@ -255,7 +278,7 @@ private ShardIterator preferenceActiveShardIterator( } // no more preference if (index == -1 || index == preference.length() - 1) { - return shardRoutings(indexShard, nodes, collectorService, nodeCounts); + return shardRoutings(indexShard, nodes, collectorService, nodeCounts, weightedRoutingMetadata); } else { // update the preference and continue preference = preference.substring(index + 1); @@ -298,9 +321,16 @@ private ShardIterator shardRoutings( IndexShardRoutingTable indexShard, DiscoveryNodes nodes, @Nullable ResponseCollectorService collectorService, - @Nullable Map nodeCounts + @Nullable Map nodeCounts, + @Nullable WeightedRoutingMetadata weightedRoutingMetadata ) { - if (ignoreAwarenessAttributes()) { + if (weightedRoutingMetadata != null) { + return indexShard.activeInitializingShardsWeightedIt( + weightedRoutingMetadata.getWeightedRouting(), + nodes, + getWeightedRoutingDefaultWeight() + ); + } else if (ignoreAwarenessAttributes()) { if (useAdaptiveReplicaSelection) { return indexShard.activeInitializingShardsRankedIt(collectorService, nodeCounts); } else { diff --git a/server/src/main/java/org/opensearch/cluster/routing/WeightedRoundRobin.java b/server/src/main/java/org/opensearch/cluster/routing/WeightedRoundRobin.java new file mode 100644 index 0000000000000..15d437db9c8ff --- /dev/null +++ b/server/src/main/java/org/opensearch/cluster/routing/WeightedRoundRobin.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.cluster.routing; + +import java.util.ArrayList; +import java.util.List; + +/** + * Weighted Round Robin Scheduling policy + * + */ +public class WeightedRoundRobin { + + private List> entities; + + public WeightedRoundRobin(List> entities) { + this.entities = entities; + } + + /** + * * + * @return list of entities that is ordered using weighted round-robin scheduling + * http://kb.linuxvirtualserver.org/wiki/Weighted_Round-Robin_Scheduling + */ + public List> orderEntities() { + int lastSelectedEntity = -1; + int size = entities.size(); + double currentWeight = 0; + List> orderedWeight = new ArrayList<>(); + if (size == 0) { + return null; + } + // Find maximum weight and greatest common divisor of weight across all entities + double maxWeight = 0; + double sumWeight = 0; + Double gcd = null; + for (WeightedRoundRobin.Entity entity : entities) { + maxWeight = Math.max(maxWeight, entity.getWeight()); + gcd = (gcd == null) ? entity.getWeight() : gcd(gcd, entity.getWeight()); + sumWeight += entity.getWeight() > 0 ? entity.getWeight() : 0; + } + int count = 0; + while (count < sumWeight) { + lastSelectedEntity = (lastSelectedEntity + 1) % size; + if (lastSelectedEntity == 0) { + currentWeight = currentWeight - gcd; + if (currentWeight <= 0) { + currentWeight = maxWeight; + if (currentWeight == 0) { + return orderedWeight; + } + } + } + if (entities.get(lastSelectedEntity).getWeight() >= currentWeight) { + orderedWeight.add(entities.get(lastSelectedEntity)); + count++; + } + } + return orderedWeight; + } + + /** + * Return greatest common divisor for two integers + * https://en.wikipedia.org/wiki/Greatest_common_divisor#Using_Euclid.27s_algorithm + * + * @param a first number + * @param b second number + * @return greatest common divisor + */ + private double gcd(double a, double b) { + return (b == 0) ? a : gcd(b, a % b); + } + + static final class Entity { + private double weight; + private T target; + + public Entity(double weight, T target) { + this.weight = weight; + this.target = target; + } + + public T getTarget() { + return this.target; + } + + public void setTarget(T target) { + this.target = target; + } + + public double getWeight() { + return this.weight; + } + + public void setWeight(double weight) { + this.weight = weight; + } + } + +} diff --git a/server/src/main/java/org/opensearch/cluster/routing/WeightedRouting.java b/server/src/main/java/org/opensearch/cluster/routing/WeightedRouting.java new file mode 100644 index 0000000000000..df2d8d595eaab --- /dev/null +++ b/server/src/main/java/org/opensearch/cluster/routing/WeightedRouting.java @@ -0,0 +1,75 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.cluster.routing; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +/** + * Entity for Weighted Round Robin weights + * + * @opensearch.internal + */ +public class WeightedRouting implements Writeable { + private String attributeName; + private Map weights; + + public WeightedRouting(String attributeName, Map weights) { + this.attributeName = attributeName; + this.weights = weights; + } + + public WeightedRouting(WeightedRouting weightedRouting) { + this.attributeName = weightedRouting.attributeName(); + this.weights = weightedRouting.weights; + } + + public WeightedRouting(StreamInput in) throws IOException { + attributeName = in.readString(); + weights = (Map) in.readGenericValue(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(attributeName); + out.writeGenericValue(weights); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WeightedRouting that = (WeightedRouting) o; + if (!attributeName.equals(that.attributeName)) return false; + return weights.equals(that.weights); + } + + @Override + public int hashCode() { + return Objects.hash(attributeName, weights); + } + + @Override + public String toString() { + return "WeightedRouting{" + attributeName + "}{" + weights().toString() + "}"; + } + + public Map weights() { + return this.weights; + } + + public String attributeName() { + return this.attributeName; + } +} diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index 971fb518ff1da..1665614c18496 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -529,6 +529,7 @@ public void apply(Settings value, Settings current, Settings previous) { Node.BREAKER_TYPE_KEY, OperationRouting.USE_ADAPTIVE_REPLICA_SELECTION_SETTING, OperationRouting.IGNORE_AWARENESS_ATTRIBUTES_SETTING, + OperationRouting.WEIGHTED_ROUTING_DEFAULT_WEIGHT, IndexGraveyard.SETTING_MAX_TOMBSTONES, PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING, EnableAssignmentDecider.CLUSTER_TASKS_ALLOCATION_ENABLE_SETTING, diff --git a/server/src/test/java/org/opensearch/cluster/metadata/WeightedRoutingMetadataTests.java b/server/src/test/java/org/opensearch/cluster/metadata/WeightedRoutingMetadataTests.java new file mode 100644 index 0000000000000..a0a9d2bd9586b --- /dev/null +++ b/server/src/test/java/org/opensearch/cluster/metadata/WeightedRoutingMetadataTests.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.cluster.metadata; + +import org.opensearch.cluster.routing.WeightedRouting; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.Map; + +public class WeightedRoutingMetadataTests extends AbstractXContentTestCase { + @Override + protected WeightedRoutingMetadata createTestInstance() { + Map weights = Map.of("a", 1.0, "b", 1.0, "c", 0.0); + WeightedRouting weightedRouting = new WeightedRouting("zone", weights); + WeightedRoutingMetadata weightedRoutingMetadata = new WeightedRoutingMetadata(weightedRouting); + return weightedRoutingMetadata; + } + + @Override + protected WeightedRoutingMetadata doParseInstance(XContentParser parser) throws IOException { + return WeightedRoutingMetadata.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } +} diff --git a/server/src/test/java/org/opensearch/cluster/routing/OperationRoutingTests.java b/server/src/test/java/org/opensearch/cluster/routing/OperationRoutingTests.java index 8bf2b1626292a..87cab4a006a63 100644 --- a/server/src/test/java/org/opensearch/cluster/routing/OperationRoutingTests.java +++ b/server/src/test/java/org/opensearch/cluster/routing/OperationRoutingTests.java @@ -36,6 +36,7 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.metadata.WeightedRoutingMetadata; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.routing.allocation.decider.AwarenessAllocationDecider; @@ -759,6 +760,232 @@ public void testAdaptiveReplicaSelectionWithZoneAwarenessIgnored() throws Except terminate(threadPool); } + private ClusterState clusterStateForWeightedRouting(String[] indexNames, int numShards, int numReplicas) { + DiscoveryNode[] allNodes = setUpNodesForWeightedRouting(); + ClusterState state = ClusterStateCreationUtils.state(allNodes[0], allNodes[6], allNodes); + + Map> discoveryNodeMap = new HashMap<>(); + List nodesZoneA = new ArrayList<>(); + nodesZoneA.add(allNodes[0]); + nodesZoneA.add(allNodes[1]); + + List nodesZoneB = new ArrayList<>(); + nodesZoneB.add(allNodes[2]); + nodesZoneB.add(allNodes[3]); + + List nodesZoneC = new ArrayList<>(); + nodesZoneC.add(allNodes[4]); + nodesZoneC.add(allNodes[5]); + discoveryNodeMap.put("a", nodesZoneA); + discoveryNodeMap.put("b", nodesZoneB); + discoveryNodeMap.put("c", nodesZoneC); + + // Updating cluster state with node, index and shard details + state = updateStatetoTestWeightedRouting(indexNames, numShards, numReplicas, state, discoveryNodeMap); + + return state; + + } + + private ClusterState setWeightedRoutingWeights(ClusterState clusterState, Map weights) { + WeightedRouting weightedRouting = new WeightedRouting("zone", weights); + WeightedRoutingMetadata weightedRoutingMetadata = new WeightedRoutingMetadata(weightedRouting); + Metadata.Builder metadataBuilder = Metadata.builder(clusterState.metadata()); + metadataBuilder.putCustom(WeightedRoutingMetadata.TYPE, weightedRoutingMetadata); + clusterState = ClusterState.builder(clusterState).metadata(metadataBuilder).build(); + return clusterState; + } + + public void testWeightedOperationRouting() throws Exception { + final int numIndices = 2; + final int numShards = 3; + final int numReplicas = 2; + // setting up indices + final String[] indexNames = new String[numIndices]; + for (int i = 0; i < numIndices; i++) { + indexNames[i] = "test" + i; + } + ClusterService clusterService = null; + TestThreadPool threadPool = null; + try { + ClusterState state = clusterStateForWeightedRouting(indexNames, numShards, numReplicas); + + Settings setting = Settings.builder().put("cluster.routing.allocation.awareness.attributes", "zone").build(); + + threadPool = new TestThreadPool("testThatOnlyNodesSupport"); + clusterService = ClusterServiceUtils.createClusterService(threadPool); + + OperationRouting opRouting = new OperationRouting( + setting, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + assertTrue(opRouting.ignoreAwarenessAttributes()); + Set selectedNodes = new HashSet<>(); + ResponseCollectorService collector = new ResponseCollectorService(clusterService); + Map outstandingRequests = new HashMap<>(); + + // Setting up weights for weighted round-robin in cluster state + Map weights = Map.of("a", 1.0, "b", 1.0, "c", 0.0); + state = setWeightedRoutingWeights(state, weights); + + ClusterState.Builder builder = ClusterState.builder(state); + ClusterServiceUtils.setState(clusterService, builder); + + // search shards call + GroupShardsIterator groupIterator = opRouting.searchShards( + state, + indexNames, + null, + null, + collector, + outstandingRequests + + ); + + for (ShardIterator it : groupIterator) { + List shardRoutings = Collections.singletonList(it.nextOrNull()); + for (ShardRouting shardRouting : shardRoutings) { + selectedNodes.add(shardRouting.currentNodeId()); + } + } + // tests no shards are assigned to nodes in zone c + for (String nodeID : selectedNodes) { + // No shards are assigned to nodes in zone c since its weight is 0 + assertFalse(nodeID.contains("c")); + } + + selectedNodes = new HashSet<>(); + setting = Settings.builder().put("cluster.routing.allocation.awareness.attributes", "zone").build(); + + // Updating weighted round robin weights in cluster state + weights = Map.of("a", 1.0, "b", 0.0, "c", 1.0); + state = setWeightedRoutingWeights(state, weights); + + builder = ClusterState.builder(state); + ClusterServiceUtils.setState(clusterService, builder); + + opRouting = new OperationRouting(setting, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)); + + // search shards call + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + + for (ShardIterator it : groupIterator) { + List shardRoutings = Collections.singletonList(it.nextOrNull()); + for (ShardRouting shardRouting : shardRoutings) { + selectedNodes.add(shardRouting.currentNodeId()); + } + } + // tests that no shards are assigned to zone with weight zero + for (String nodeID : selectedNodes) { + // No shards are assigned to nodes in zone b since its weight is 0 + assertFalse(nodeID.contains("b")); + } + } finally { + IOUtils.close(clusterService); + terminate(threadPool); + } + } + + public void testWeightedOperationRoutingWeightUndefinedForOneZone() throws Exception { + final int numIndices = 2; + final int numShards = 3; + final int numReplicas = 2; + // setting up indices + final String[] indexNames = new String[numIndices]; + for (int i = 0; i < numIndices; i++) { + indexNames[i] = "test" + i; + } + + ClusterService clusterService = null; + TestThreadPool threadPool = null; + try { + ClusterState state = clusterStateForWeightedRouting(indexNames, numShards, numReplicas); + + Settings setting = Settings.builder().put("cluster.routing.allocation.awareness.attributes", "zone").build(); + + threadPool = new TestThreadPool("testThatOnlyNodesSupport"); + clusterService = ClusterServiceUtils.createClusterService(threadPool); + + OperationRouting opRouting = new OperationRouting( + setting, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + assertTrue(opRouting.ignoreAwarenessAttributes()); + Set selectedNodes = new HashSet<>(); + ResponseCollectorService collector = new ResponseCollectorService(clusterService); + Map outstandingRequests = new HashMap<>(); + + // Setting up weights for weighted round-robin in cluster state, weight for nodes in zone b is not set + Map weights = Map.of("a", 1.0, "c", 0.0); + state = setWeightedRoutingWeights(state, weights); + ClusterServiceUtils.setState(clusterService, ClusterState.builder(state)); + + // search shards call + GroupShardsIterator groupIterator = opRouting.searchShards( + state, + indexNames, + null, + null, + collector, + outstandingRequests + + ); + + for (ShardIterator it : groupIterator) { + List shardRoutings = Collections.singletonList(it.nextOrNull()); + for (ShardRouting shardRouting : shardRoutings) { + selectedNodes.add(shardRouting.currentNodeId()); + } + } + boolean weighAwayNodesInUndefinedZone = true; + // tests no shards are assigned to nodes in zone c + // tests shards are assigned to nodes in zone b + for (String nodeID : selectedNodes) { + // shard from nodes in zone c is not selected since its weight is 0 + assertFalse(nodeID.contains("c")); + if (nodeID.contains("b")) { + weighAwayNodesInUndefinedZone = false; + } + } + assertFalse(weighAwayNodesInUndefinedZone); + + selectedNodes = new HashSet<>(); + setting = Settings.builder().put("cluster.routing.allocation.awareness.attributes", "zone").build(); + + // Updating weighted round robin weights in cluster state + weights = Map.of("a", 0.0, "b", 1.0); + + state = setWeightedRoutingWeights(state, weights); + ClusterServiceUtils.setState(clusterService, ClusterState.builder(state)); + + opRouting = new OperationRouting(setting, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)); + + // search shards call + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + + for (ShardIterator it : groupIterator) { + List shardRoutings = Collections.singletonList(it.nextOrNull()); + for (ShardRouting shardRouting : shardRoutings) { + selectedNodes.add(shardRouting.currentNodeId()); + } + } + // tests that no shards are assigned to zone with weight zero + // tests shards are assigned to nodes in zone c + weighAwayNodesInUndefinedZone = true; + for (String nodeID : selectedNodes) { + // shard from nodes in zone a is not selected since its weight is 0 + assertFalse(nodeID.contains("a")); + if (nodeID.contains("c")) { + weighAwayNodesInUndefinedZone = false; + } + } + assertFalse(weighAwayNodesInUndefinedZone); + } finally { + IOUtils.close(clusterService); + terminate(threadPool); + } + } + private DiscoveryNode[] setupNodes() { // Sets up two data nodes in zone-a and one data node in zone-b List zones = Arrays.asList("a", "a", "b"); @@ -785,6 +1012,32 @@ private DiscoveryNode[] setupNodes() { return allNodes; } + private DiscoveryNode[] setUpNodesForWeightedRouting() { + List zones = Arrays.asList("a", "a", "b", "b", "c", "c"); + DiscoveryNode[] allNodes = new DiscoveryNode[7]; + int i = 0; + for (String zone : zones) { + DiscoveryNode node = new DiscoveryNode( + "node_" + zone + "_" + i, + buildNewFakeTransportAddress(), + singletonMap("zone", zone), + Collections.singleton(DiscoveryNodeRole.DATA_ROLE), + Version.CURRENT + ); + allNodes[i++] = node; + } + + DiscoveryNode clusterManager = new DiscoveryNode( + "cluster-manager", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.singleton(DiscoveryNodeRole.CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + allNodes[i] = clusterManager; + return allNodes; + } + public void testAllocationAwarenessDeprecation() { OperationRouting routing = new OperationRouting( Settings.builder() @@ -841,4 +1094,69 @@ private ClusterState updateStatetoTestARS( clusterState.routingTable(routingTableBuilder.build()); return clusterState.build(); } + + private ClusterState updateStatetoTestWeightedRouting( + String[] indices, + int numberOfShards, + int numberOfReplicas, + ClusterState state, + Map> discoveryNodeMap + ) { + RoutingTable.Builder routingTableBuilder = RoutingTable.builder(); + Metadata.Builder metadataBuilder = Metadata.builder(); + ClusterState.Builder clusterState = ClusterState.builder(state); + List nodesZoneA = discoveryNodeMap.get("a"); + List nodesZoneB = discoveryNodeMap.get("b"); + List nodesZoneC = discoveryNodeMap.get("c"); + for (String index : indices) { + IndexMetadata indexMetadata = IndexMetadata.builder(index) + .settings( + Settings.builder() + .put(SETTING_VERSION_CREATED, Version.CURRENT) + .put(SETTING_NUMBER_OF_SHARDS, numberOfShards) + .put(SETTING_NUMBER_OF_REPLICAS, numberOfReplicas) + .put(SETTING_CREATION_DATE, System.currentTimeMillis()) + ) + .build(); + metadataBuilder.put(indexMetadata, false).generateClusterUuidIfNeeded(); + IndexRoutingTable.Builder indexRoutingTableBuilder = IndexRoutingTable.builder(indexMetadata.getIndex()); + for (int i = 0; i < numberOfShards; i++) { + final ShardId shardId = new ShardId(index, "_na_", i); + IndexShardRoutingTable.Builder indexShardRoutingBuilder = new IndexShardRoutingTable.Builder(shardId); + // Assign all the primary shards on nodes in zone-a (node_a0 or node_a1) + indexShardRoutingBuilder.addShard( + TestShardRouting.newShardRouting( + index, + i, + nodesZoneA.get(randomInt(nodesZoneA.size() - 1)).getId(), + null, + true, + ShardRoutingState.STARTED + ) + ); + for (int replica = 0; replica < numberOfReplicas; replica++) { + // Assign all the replicas on nodes in zone-b (node_b2) + String nodeId = ""; + if (replica == 0) { + nodeId = nodesZoneB.get(randomInt(nodesZoneB.size() - 1)).getId(); + } else { + nodeId = nodesZoneC.get(randomInt(nodesZoneC.size() - 1)).getId(); + } + indexShardRoutingBuilder.addShard( + TestShardRouting.newShardRouting(index, i, nodeId, null, false, ShardRoutingState.STARTED) + ); + } + indexRoutingTableBuilder.addIndexShard(indexShardRoutingBuilder.build()); + } + routingTableBuilder.add(indexRoutingTableBuilder.build()); + } + // add weighted routing weights in metadata + Map weights = Map.of("a", 1.0, "b", 1.0, "c", 0.0); + WeightedRouting weightedRouting = new WeightedRouting("zone", weights); + WeightedRoutingMetadata weightedRoutingMetadata = new WeightedRoutingMetadata(weightedRouting); + metadataBuilder.putCustom(WeightedRoutingMetadata.TYPE, weightedRoutingMetadata); + clusterState.metadata(metadataBuilder); + clusterState.routingTable(routingTableBuilder.build()); + return clusterState.build(); + } } diff --git a/server/src/test/java/org/opensearch/cluster/routing/WeightedRoundRobinTests.java b/server/src/test/java/org/opensearch/cluster/routing/WeightedRoundRobinTests.java new file mode 100644 index 0000000000000..5f62d30486e86 --- /dev/null +++ b/server/src/test/java/org/opensearch/cluster/routing/WeightedRoundRobinTests.java @@ -0,0 +1,151 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.cluster.routing; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class WeightedRoundRobinTests extends OpenSearchTestCase { + + public void testWeightedRoundRobinOrder() { + // weights set as A:4, B:3, C:2 + List> entity = new ArrayList>(); + entity.add(new WeightedRoundRobin.Entity<>(4, "A")); + entity.add(new WeightedRoundRobin.Entity<>(3, "B")); + entity.add(new WeightedRoundRobin.Entity<>(2, "C")); + WeightedRoundRobin weightedRoundRobin = new WeightedRoundRobin(entity); + List> orderedEntities = weightedRoundRobin.orderEntities(); + List expectedOrdering = Arrays.asList("A", "A", "B", "A", "B", "C", "A", "B", "C"); + List actualOrdering = new ArrayList<>(); + for (WeightedRoundRobin.Entity en : orderedEntities) { + actualOrdering.add(en.getTarget()); + } + assertEquals(expectedOrdering, actualOrdering); + + // weights set as A:1, B:1, C:0 + entity = new ArrayList>(); + entity.add(new WeightedRoundRobin.Entity<>(1, "A")); + entity.add(new WeightedRoundRobin.Entity<>(1, "B")); + entity.add(new WeightedRoundRobin.Entity<>(0, "C")); + weightedRoundRobin = new WeightedRoundRobin(entity); + orderedEntities = weightedRoundRobin.orderEntities(); + expectedOrdering = Arrays.asList("A", "B"); + actualOrdering = new ArrayList<>(); + for (WeightedRoundRobin.Entity en : orderedEntities) { + actualOrdering.add(en.getTarget()); + } + assertEquals(expectedOrdering, actualOrdering); + + // weights set as A:0, B:0, C:0 + entity = new ArrayList>(); + entity.add(new WeightedRoundRobin.Entity<>(0, "A")); + entity.add(new WeightedRoundRobin.Entity<>(0, "B")); + entity.add(new WeightedRoundRobin.Entity<>(0, "C")); + weightedRoundRobin = new WeightedRoundRobin(entity); + orderedEntities = weightedRoundRobin.orderEntities(); + expectedOrdering = Arrays.asList(); + actualOrdering = new ArrayList<>(); + for (WeightedRoundRobin.Entity en : orderedEntities) { + actualOrdering.add(en.getTarget()); + } + assertEquals(expectedOrdering, actualOrdering); + + // weights set as A:-1, B:0, C:1 + entity = new ArrayList>(); + entity.add(new WeightedRoundRobin.Entity<>(-1, "A")); + entity.add(new WeightedRoundRobin.Entity<>(0, "B")); + entity.add(new WeightedRoundRobin.Entity<>(1, "C")); + weightedRoundRobin = new WeightedRoundRobin(entity); + orderedEntities = weightedRoundRobin.orderEntities(); + expectedOrdering = Arrays.asList("C"); + actualOrdering = new ArrayList<>(); + for (WeightedRoundRobin.Entity en : orderedEntities) { + actualOrdering.add(en.getTarget()); + } + assertEquals(expectedOrdering, actualOrdering); + + // weights set as A:-1, B:3, C:0, D:10 + entity = new ArrayList>(); + entity.add(new WeightedRoundRobin.Entity<>(-1, "A")); + entity.add(new WeightedRoundRobin.Entity<>(3, "B")); + entity.add(new WeightedRoundRobin.Entity<>(0, "C")); + entity.add(new WeightedRoundRobin.Entity<>(10, "D")); + weightedRoundRobin = new WeightedRoundRobin(entity); + orderedEntities = weightedRoundRobin.orderEntities(); + expectedOrdering = Arrays.asList("B", "D", "B", "D", "B", "D", "D", "D", "D", "D", "D", "D", "D"); + actualOrdering = new ArrayList<>(); + for (WeightedRoundRobin.Entity en : orderedEntities) { + actualOrdering.add(en.getTarget()); + } + assertEquals(expectedOrdering, actualOrdering); + + // weights set as A:-1, B:3, C:0, D:10000 + entity = new ArrayList>(); + entity.add(new WeightedRoundRobin.Entity<>(-1, "A")); + entity.add(new WeightedRoundRobin.Entity<>(3, "B")); + entity.add(new WeightedRoundRobin.Entity<>(0, "C")); + entity.add(new WeightedRoundRobin.Entity<>(10000, "D")); + weightedRoundRobin = new WeightedRoundRobin(entity); + orderedEntities = weightedRoundRobin.orderEntities(); + assertEquals(10003, orderedEntities.size()); + // Count of D's + int countD = 0; + // Count of B's + int countB = 0; + for (WeightedRoundRobin.Entity en : orderedEntities) { + if (en.getTarget().equals("D")) { + countD++; + } else if (en.getTarget().equals("B")) { + countB++; + } + } + assertEquals(3, countB); + assertEquals(10000, countD); + + // weights set C:0 + entity = new ArrayList>(); + entity.add(new WeightedRoundRobin.Entity<>(0, "C")); + weightedRoundRobin = new WeightedRoundRobin(entity); + orderedEntities = weightedRoundRobin.orderEntities(); + expectedOrdering = Arrays.asList(); + actualOrdering = new ArrayList<>(); + for (WeightedRoundRobin.Entity en : orderedEntities) { + actualOrdering.add(en.getTarget()); + } + assertEquals(expectedOrdering, actualOrdering); + + // weights set C:1 + entity = new ArrayList>(); + entity.add(new WeightedRoundRobin.Entity<>(1, "C")); + weightedRoundRobin = new WeightedRoundRobin(entity); + orderedEntities = weightedRoundRobin.orderEntities(); + expectedOrdering = Arrays.asList("C"); + actualOrdering = new ArrayList<>(); + for (WeightedRoundRobin.Entity en : orderedEntities) { + actualOrdering.add(en.getTarget()); + } + assertEquals(expectedOrdering, actualOrdering); + + // weights set C:2 + entity = new ArrayList>(); + entity.add(new WeightedRoundRobin.Entity<>(2, "C")); + weightedRoundRobin = new WeightedRoundRobin(entity); + orderedEntities = weightedRoundRobin.orderEntities(); + expectedOrdering = Arrays.asList("C", "C"); + actualOrdering = new ArrayList<>(); + for (WeightedRoundRobin.Entity en : orderedEntities) { + actualOrdering.add(en.getTarget()); + } + assertEquals(expectedOrdering, actualOrdering); + } + +} diff --git a/server/src/test/java/org/opensearch/cluster/structure/RoutingIteratorTests.java b/server/src/test/java/org/opensearch/cluster/structure/RoutingIteratorTests.java index 68ad47fa1bbc9..8f5aa1b764551 100644 --- a/server/src/test/java/org/opensearch/cluster/structure/RoutingIteratorTests.java +++ b/server/src/test/java/org/opensearch/cluster/structure/RoutingIteratorTests.java @@ -40,6 +40,7 @@ import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.routing.GroupShardsIterator; +import org.opensearch.cluster.routing.IndexShardRoutingTable; import org.opensearch.cluster.routing.OperationRouting; import org.opensearch.cluster.routing.PlainShardIterator; import org.opensearch.cluster.routing.RotationShardShuffler; @@ -48,11 +49,15 @@ import org.opensearch.cluster.routing.ShardRouting; import org.opensearch.cluster.routing.ShardShuffler; import org.opensearch.cluster.routing.ShardsIterator; +import org.opensearch.cluster.routing.WeightedRouting; import org.opensearch.cluster.routing.allocation.AllocationService; import org.opensearch.cluster.routing.allocation.decider.ClusterRebalanceAllocationDecider; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.index.shard.ShardId; +import org.opensearch.test.ClusterServiceUtils; +import org.opensearch.threadpool.TestThreadPool; import java.util.Arrays; import java.util.Collections; @@ -497,4 +502,209 @@ public void testReplicaShardPreferenceIters() throws Exception { } } + public void testWeightedRoutingWithDifferentWeights() { + TestThreadPool threadPool = null; + try { + Settings.Builder settings = Settings.builder() + .put("cluster.routing.allocation.node_concurrent_recoveries", 10) + .put("cluster.routing.allocation.awareness.attributes", "zone"); + AllocationService strategy = createAllocationService(settings.build()); + + Metadata metadata = Metadata.builder() + .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2)) + .build(); + + RoutingTable routingTable = RoutingTable.builder().addAsNew(metadata.index("test")).build(); + + ClusterState clusterState = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTable) + .build(); + + threadPool = new TestThreadPool("testThatOnlyNodesSupport"); + ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool); + + Map node1Attributes = new HashMap<>(); + node1Attributes.put("zone", "zone1"); + Map node2Attributes = new HashMap<>(); + node2Attributes.put("zone", "zone2"); + Map node3Attributes = new HashMap<>(); + node3Attributes.put("zone", "zone3"); + clusterState = ClusterState.builder(clusterState) + .nodes( + DiscoveryNodes.builder() + .add(newNode("node1", unmodifiableMap(node1Attributes))) + .add(newNode("node2", unmodifiableMap(node2Attributes))) + .add(newNode("node3", unmodifiableMap(node3Attributes))) + .localNodeId("node1") + ) + .build(); + clusterState = strategy.reroute(clusterState, "reroute"); + + clusterState = startInitializingShardsAndReroute(strategy, clusterState); + clusterState = startInitializingShardsAndReroute(strategy, clusterState); + + Map weights = Map.of("zone1", 1.0, "zone2", 1.0, "zone3", 0.0); + WeightedRouting weightedRouting = new WeightedRouting("zone", weights); + + ShardIterator shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1); + + assertEquals(2, shardIterator.size()); + ShardRouting shardRouting; + while (shardIterator.remaining() > 0) { + shardRouting = shardIterator.nextOrNull(); + assertNotNull(shardRouting); + assertFalse(Arrays.asList("node3").contains(shardRouting.currentNodeId())); + } + + weights = Map.of("zone1", 1.0, "zone2", 1.0, "zone3", 1.0); + weightedRouting = new WeightedRouting("zone", weights); + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1); + assertEquals(3, shardIterator.size()); + + weights = Map.of("zone1", -1.0, "zone2", 0.0, "zone3", 1.0); + weightedRouting = new WeightedRouting("zone", weights); + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1); + assertEquals(1, shardIterator.size()); + while (shardIterator.remaining() > 0) { + shardRouting = shardIterator.nextOrNull(); + assertNotNull(shardRouting); + assertFalse(Arrays.asList("node2", "node1").contains(shardRouting.currentNodeId())); + } + + weights = Map.of("zone1", 0.0, "zone2", 0.0, "zone3", 0.0); + weightedRouting = new WeightedRouting("zone", weights); + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1); + assertEquals(0, shardIterator.size()); + } finally { + terminate(threadPool); + } + } + + public void testWeightedRoutingInMemoryStore() { + TestThreadPool threadPool = null; + try { + Settings.Builder settings = Settings.builder() + .put("cluster.routing.allocation.node_concurrent_recoveries", 10) + .put("cluster.routing.allocation.awareness.attributes", "zone"); + AllocationService strategy = createAllocationService(settings.build()); + + Metadata metadata = Metadata.builder() + .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2)) + .build(); + + RoutingTable routingTable = RoutingTable.builder().addAsNew(metadata.index("test")).build(); + + ClusterState clusterState = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTable) + .build(); + + threadPool = new TestThreadPool("testThatOnlyNodesSupport"); + ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool); + + Map node1Attributes = new HashMap<>(); + node1Attributes.put("zone", "zone1"); + Map node2Attributes = new HashMap<>(); + node2Attributes.put("zone", "zone2"); + Map node3Attributes = new HashMap<>(); + node3Attributes.put("zone", "zone3"); + clusterState = ClusterState.builder(clusterState) + .nodes( + DiscoveryNodes.builder() + .add(newNode("node1", unmodifiableMap(node1Attributes))) + .add(newNode("node2", unmodifiableMap(node2Attributes))) + .add(newNode("node3", unmodifiableMap(node3Attributes))) + .localNodeId("node1") + ) + .build(); + clusterState = strategy.reroute(clusterState, "reroute"); + + clusterState = startInitializingShardsAndReroute(strategy, clusterState); + clusterState = startInitializingShardsAndReroute(strategy, clusterState); + + Map weights = Map.of("zone1", 1.0, "zone2", 1.0, "zone3", 0.0); + WeightedRouting weightedRouting = new WeightedRouting("zone", weights); + + IndexShardRoutingTable indexShardRoutingTable = clusterState.routingTable().index("test").shard(0); + + assertNull( + indexShardRoutingTable.getActiveShardsByWeight().get(new IndexShardRoutingTable.WeightedRoutingKey(weightedRouting)) + ); + ShardIterator shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1); + assertEquals(2, shardIterator.size()); + ShardRouting shardRouting; + while (shardIterator.remaining() > 0) { + shardRouting = shardIterator.nextOrNull(); + assertNotNull(shardRouting); + assertFalse(Arrays.asList("node3").contains(shardRouting.currentNodeId())); + } + + // Make iterator call with same WeightedRouting instance + assertNotNull( + indexShardRoutingTable.getActiveShardsByWeight().get(new IndexShardRoutingTable.WeightedRoutingKey(weightedRouting)) + ); + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1); + assertEquals(2, shardIterator.size()); + while (shardIterator.remaining() > 0) { + shardRouting = shardIterator.nextOrNull(); + assertNotNull(shardRouting); + assertFalse(Arrays.asList("node3").contains(shardRouting.currentNodeId())); + } + + // Make iterator call with new instance of WeightedRouting but same weights + Map weights1 = Map.of("zone1", 1.0, "zone2", 1.0, "zone3", 0.0); + weightedRouting = new WeightedRouting("zone", weights1); + assertNotNull( + indexShardRoutingTable.getActiveShardsByWeight().get(new IndexShardRoutingTable.WeightedRoutingKey(weightedRouting)) + ); + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1); + assertEquals(2, shardIterator.size()); + while (shardIterator.remaining() > 0) { + shardRouting = shardIterator.nextOrNull(); + assertNotNull(shardRouting); + assertFalse(Arrays.asList("node3").contains(shardRouting.currentNodeId())); + } + + // Make iterator call with different weights + Map weights2 = Map.of("zone1", 1.0, "zone2", 0.0, "zone3", 1.0); + weightedRouting = new WeightedRouting("zone", weights2); + assertNull( + indexShardRoutingTable.getActiveShardsByWeight().get(new IndexShardRoutingTable.WeightedRoutingKey(weightedRouting)) + ); + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1); + assertEquals(2, shardIterator.size()); + while (shardIterator.remaining() > 0) { + shardRouting = shardIterator.nextOrNull(); + assertNotNull(shardRouting); + assertFalse(Arrays.asList("node2").contains(shardRouting.currentNodeId())); + } + } finally { + terminate(threadPool); + } + } }