From 644d9085bc2fc8a09e5cbc2a82dfbcaec69d830d Mon Sep 17 00:00:00 2001 From: Aman Khare Date: Tue, 5 Mar 2024 23:39:52 +0530 Subject: [PATCH] Add unit test for ShardBatchCache implementation Signed-off-by: Aman Khare --- .../gateway/AsyncShardBatchFetch.java | 26 +- .../opensearch/gateway/ShardBatchCache.java | 29 +-- .../org/opensearch/gateway/BatchTestUtil.java | 25 ++ .../gateway/ShardBatchCacheTests.java | 231 ++++++++++++++++++ 4 files changed, 278 insertions(+), 33 deletions(-) create mode 100644 server/src/test/java/org/opensearch/gateway/BatchTestUtil.java create mode 100644 server/src/test/java/org/opensearch/gateway/ShardBatchCacheTests.java diff --git a/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java b/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java index 7d9dcd909e3df..b787bae5e5057 100644 --- a/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java +++ b/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java @@ -16,6 +16,7 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.indices.store.ShardAttributes; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; @@ -35,6 +36,9 @@ */ public abstract class AsyncShardBatchFetch extends AsyncShardFetch { + private final Consumer removeShardFromBatch; + private final List failedShards; + @SuppressWarnings("unchecked") AsyncShardBatchFetch( Logger logger, @@ -49,6 +53,8 @@ public abstract class AsyncShardBatchFetch handleFailedShard ) { super(logger, type, shardAttributesMap, action, batchId); + this.removeShardFromBatch = handleFailedShard; + this.failedShards = new ArrayList<>(); this.cache = new ShardBatchCache<>( logger, type, @@ -71,7 +77,6 @@ public abstract class AsyncShardBatchFetch fetchData(DiscoveryNodes nodes, Map> ignoreNodes) { - List failedShards = cleanUpFailedShards(); if (failedShards.isEmpty() == false) { // trigger a reroute if there are any shards failed, to make sure they're picked up in next run logger.trace("triggering another reroute for failed shards in {}", reroutingKey); @@ -81,16 +86,15 @@ public synchronized FetchResult fetchData(DiscoveryNodes nodes, Map cleanUpFailedShards() { - List failedShards = cache.getFailedShards(); - if (failedShards != null && failedShards.isEmpty() == false) { - shardAttributesMap.keySet().removeIf(failedShards::contains); - } - return failedShards; + private void cleanUpFailedShard(ShardId shardId) { + shardAttributesMap.remove(shardId); + removeShardFromBatch.accept(shardId); + failedShards.add(shardId); } /** @@ -100,7 +104,7 @@ private List cleanUpFailedShards() { * @param shardId shardId to be removed from the batch. */ public void clearShard(ShardId shardId) { - this.shardAttributesMap.remove(shardId); - this.cache.clearShardCache(shardId); + shardAttributesMap.remove(shardId); + cache.deleteData(shardId); } } diff --git a/server/src/main/java/org/opensearch/gateway/ShardBatchCache.java b/server/src/main/java/org/opensearch/gateway/ShardBatchCache.java index 77880092374bd..336a5b5c94a42 100644 --- a/server/src/main/java/org/opensearch/gateway/ShardBatchCache.java +++ b/server/src/main/java/org/opensearch/gateway/ShardBatchCache.java @@ -51,7 +51,6 @@ public class ShardBatchCache arrayToShardId; private final Function> shardsBatchDataGetter; private final Supplier emptyResponseBuilder; - private final Set failedShards; private final Consumer handleFailedShard; public ShardBatchCache( @@ -72,7 +71,6 @@ public ShardBatchCache( this.responseConstructor = responseConstructor; this.shardsBatchDataGetter = shardsBatchDataGetter; this.emptyResponseBuilder = emptyResponseBuilder; - failedShards = new HashSet<>(); cache = new HashMap<>(); shardIdToArray = new HashMap<>(); arrayToShardId = new HashMap<>(); @@ -86,7 +84,7 @@ public ShardBatchCache( } @Override - public void clearShardCache(ShardId shardId) { + public void deleteData(ShardId shardId) { if (shardIdToArray.containsKey(shardId)) { Integer shardIdIndex = shardIdToArray.remove(shardId); for (String nodeId : cache.keySet()) { @@ -128,7 +126,7 @@ public void initData(DiscoveryNode node) { public void putData(DiscoveryNode node, T response) { NodeEntry nodeEntry = cache.get(node.getId()); Map batchResponse = shardsBatchDataGetter.apply(response); - failedShards.addAll(filterFailedShards(batchResponse)); + filterFailedShards(batchResponse); nodeEntry.doneFetching(batchResponse, shardIdToArray); } @@ -136,11 +134,8 @@ public void putData(DiscoveryNode node, T response) { * Return the shard for which we got unhandled exceptions. * * @param batchResponse response from one node for the batch. - * @return List of failed shards. */ - private List filterFailedShards(Map batchResponse) { - logger.trace("filtering failed shards"); - List failedShards = new ArrayList<>(); + private void filterFailedShards(Map batchResponse) { for (Iterator it = batchResponse.keySet().iterator(); it.hasNext();) { ShardId shardId = it.next(); if (batchResponse.get(shardId) != null) { @@ -149,11 +144,9 @@ private List filterFailedShards(Map batchResponse) { // the batch Exception shardException = batchResponse.get(shardId).getException(); // if the request got rejected or timed out, we need to try it again next time... - if (shardException instanceof OpenSearchRejectedExecutionException - || shardException instanceof ReceiveTimeoutTransportException - || shardException instanceof OpenSearchTimeoutException) { - logger.trace("got unhandled retryable exception for shard {} {}", shardId.toString(), shardException.toString()); - failedShards.add(shardId); + if (retryableException(shardException)) { + logger.trace("got unhandled retryable exception for shard {} {}", shardId.toString(), + shardException.toString()); handleFailedShard.accept(shardId); // remove this failed entry. So, while storing the data, we don't need to re-process it. it.remove(); @@ -161,7 +154,6 @@ private List filterFailedShards(Map batchResponse) { } } } - return failedShards; } @Override @@ -169,13 +161,6 @@ public T getData(DiscoveryNode node) { return this.responseConstructor.apply(node, getBatchData(cache.get(node.getId()))); } - @Override - public List getFailedShards() { - List defectedShards = List.copyOf(failedShards); - failedShards.clear(); - return defectedShards; - } - private HashMap getBatchData(NodeEntry nodeEntry) { V[] nodeShardEntries = nodeEntry.getData(); boolean[] emptyResponses = nodeEntry.getEmptyShardResponse(); @@ -197,7 +182,7 @@ private void fillShardIdKeys(Set shardIds) { } this.shardIdToArray.keySet().removeIf(shardId -> { if (!shardIds.contains(shardId)) { - clearShardCache(shardId); + deleteData(shardId); return true; } else { return false; diff --git a/server/src/test/java/org/opensearch/gateway/BatchTestUtil.java b/server/src/test/java/org/opensearch/gateway/BatchTestUtil.java new file mode 100644 index 0000000000000..69f0cfeeb2c7d --- /dev/null +++ b/server/src/test/java/org/opensearch/gateway/BatchTestUtil.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.gateway; + +import org.opensearch.core.index.shard.ShardId; + +import java.util.ArrayList; +import java.util.List; + +public class BatchTestUtil { + public static List setUpShards(int numberOfShards) { + List shards = new ArrayList<>(); + for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) { + ShardId shardId = new ShardId("test", "_na_", shardNumber); + shards.add(shardId); + } + return shards; + } +} diff --git a/server/src/test/java/org/opensearch/gateway/ShardBatchCacheTests.java b/server/src/test/java/org/opensearch/gateway/ShardBatchCacheTests.java new file mode 100644 index 0000000000000..598de3d98f9b1 --- /dev/null +++ b/server/src/test/java/org/opensearch/gateway/ShardBatchCacheTests.java @@ -0,0 +1,231 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.gateway; + +import org.opensearch.cluster.OpenSearchAllocationTestCase; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardRoutingState; +import org.opensearch.cluster.routing.TestShardRouting; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard; +import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch; +import org.opensearch.indices.store.ShardAttributes; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class ShardBatchCacheTests extends OpenSearchAllocationTestCase { + private static final String BATCH_ID = "b1"; + private final DiscoveryNode node1 = newNode("node1"); + private final DiscoveryNode node2 = newNode("node2"); + private final Map batchInfo = new HashMap<>(); + private ShardBatchCache shardCache; + private List shardsInBatch = new ArrayList<>(); + private static final int NUMBER_OF_SHARDS_DEFAULT = 10; + + private enum ResponseType { + NULL, + EMPTY, + FAILURE, + VALID + } + + public void setupShardBatchCache(String batchId, int numberOfShards) { + Map shardAttributesMap = new HashMap<>(); + fillShards(shardAttributesMap, numberOfShards); + this.shardCache = new ShardBatchCache<>( + logger, + "batch_shards_started", + shardAttributesMap, + "BatchID=[" + batchId + "]", + NodeGatewayStartedShard.class, + NodeGatewayStartedShardsBatch::new, + NodeGatewayStartedShardsBatch::getNodeGatewayStartedShardsBatch, + () -> new NodeGatewayStartedShard(null, false, null, null), + this::removeShard + ); + } + + public void testClearShardCache() { + setupShardBatchCache(BATCH_ID, NUMBER_OF_SHARDS_DEFAULT); + ShardId shard = shardsInBatch.iterator().next(); + this.shardCache.initData(node1); + this.shardCache.markAsFetching(List.of(node1.getId()), 1); + this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, ResponseType.EMPTY))); + assertTrue( + this.shardCache.getCacheData(DiscoveryNodes.builder().add(node1).build(), null) + .get(node1) + .getNodeGatewayStartedShardsBatch() + .containsKey(shard) + ); + this.shardCache.deleteData(shard); + assertFalse( + this.shardCache.getCacheData(DiscoveryNodes.builder().add(node1).build(), null) + .get(node1) + .getNodeGatewayStartedShardsBatch() + .containsKey(shard) + ); + } + + public void testGetCacheData() { + setupShardBatchCache(BATCH_ID, NUMBER_OF_SHARDS_DEFAULT); + ShardId shard = shardsInBatch.iterator().next(); + this.shardCache.initData(node1); + this.shardCache.initData(node2); + this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1); + this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, + ResponseType.EMPTY))); + assertTrue( + this.shardCache.getCacheData(DiscoveryNodes.builder().add(node1).build(), null) + .get(node1) + .getNodeGatewayStartedShardsBatch() + .containsKey(shard) + ); + assertTrue( + this.shardCache.getCacheData(DiscoveryNodes.builder().add(node2).build(), null) + .get(node2) + .getNodeGatewayStartedShardsBatch() + .isEmpty() + ); + } + + public void testInitCacheData() { + setupShardBatchCache(BATCH_ID, NUMBER_OF_SHARDS_DEFAULT); + this.shardCache.initData(node1); + this.shardCache.initData(node2); + assertEquals(2, shardCache.getCache().size()); + + // test getData without fetch + assertTrue(shardCache.getData(node1).getNodeGatewayStartedShardsBatch().isEmpty()); + } + + public void testPutData() { + // test empty and non-empty responses + setupShardBatchCache(BATCH_ID, NUMBER_OF_SHARDS_DEFAULT); + ShardId shard = shardsInBatch.iterator().next(); + this.shardCache.initData(node1); + this.shardCache.initData(node2); + this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1); + this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, + ResponseType.VALID))); + this.shardCache.putData(node2, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, ResponseType.EMPTY))); + + Map fetchData = shardCache.getCacheData( + DiscoveryNodes.builder().add(node1).add(node2).build(), + null + ); + assertEquals(2, fetchData.size()); + assertEquals(10, fetchData.get(node1).getNodeGatewayStartedShardsBatch().size()); + assertEquals("alloc-1", fetchData.get(node1).getNodeGatewayStartedShardsBatch().get(shard).allocationId()); + + assertEquals(10, fetchData.get(node2).getNodeGatewayStartedShardsBatch().size()); + assertTrue(fetchData.get(node2).getNodeGatewayStartedShardsBatch().get(shard).isEmpty()); + + // test GetData after fetch + assertEquals(10, shardCache.getData(node1).getNodeGatewayStartedShardsBatch().size()); + } + + public void testNullResponses() { + setupShardBatchCache(BATCH_ID, NUMBER_OF_SHARDS_DEFAULT); + this.shardCache.initData(node1); + this.shardCache.markAsFetching(List.of(node1.getId()), 1); + this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, + ResponseType.NULL))); + + Map fetchData = shardCache.getCacheData( + DiscoveryNodes.builder().add(node1).build(), null); + assertTrue(fetchData.get(node1).getNodeGatewayStartedShardsBatch().isEmpty()); + } + + public void testFilterFailedShards() { + setupShardBatchCache(BATCH_ID, NUMBER_OF_SHARDS_DEFAULT); + this.shardCache.initData(node1); + this.shardCache.initData(node2); + this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1); + this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, + getFailedPrimaryResponse(shardsInBatch, 5))); + Map fetchData = shardCache.getCacheData( + DiscoveryNodes.builder().add(node1).add(node2).build(), null); + + assertEquals(5, batchInfo.size()); + assertEquals(2, fetchData.size()); + assertEquals(5, fetchData.get(node1).getNodeGatewayStartedShardsBatch().size()); + assertTrue(fetchData.get(node2).getNodeGatewayStartedShardsBatch().isEmpty()); + } + + private Map getPrimaryResponse(List shards, ResponseType responseType) { + int allocationId = 1; + Map shardData = new HashMap<>(); + for (ShardId shard : shards) { + switch (responseType) { + case NULL: + shardData.put(shard, null); + break; + case EMPTY: + shardData.put(shard, new NodeGatewayStartedShard(null, false, null, null)); + break; + case VALID: + shardData.put(shard, new NodeGatewayStartedShard("alloc-" + allocationId++, false, null, null)); + break; + default: + throw new AssertionError("unknown response type"); + } + } + return shardData; + } + + private Map getFailedPrimaryResponse(List shards, + int failedShardsCount) { + int allocationId = 1; + Map shardData = new HashMap<>(); + for (ShardId shard : shards) { + if (failedShardsCount-- > 0) { + shardData.put(shard, new NodeGatewayStartedShard("alloc-" + allocationId++, false, null, + new OpenSearchRejectedExecutionException())); + } else { + shardData.put(shard, new NodeGatewayStartedShard("alloc-" + allocationId++, false, null, + null)); + } + } + return shardData; + } + + public void removeShard(ShardId shardId) { + batchInfo.remove(shardId); + } + + private void fillShards(Map shardAttributesMap, int numberOfShards) { + shardsInBatch = BatchTestUtil.setUpShards(numberOfShards); + for (ShardId shardId : shardsInBatch) { + ShardAttributes attr = new ShardAttributes(""); + shardAttributesMap.put(shardId, attr); + batchInfo.put( + shardId, + new ShardsBatchGatewayAllocator.ShardEntry(attr, randomShardRouting(shardId.getIndexName(), shardId.id())) + ); + } + } + + private ShardRouting randomShardRouting(String index, int shard) { + ShardRoutingState state = randomFrom(ShardRoutingState.values()); + return TestShardRouting.newShardRouting( + index, + shard, + state == ShardRoutingState.UNASSIGNED ? null : "1", + state == ShardRoutingState.RELOCATING ? "2" : null, + state != ShardRoutingState.UNASSIGNED && randomBoolean(), + state + ); + } +}