Add unit test for ShardBatchCache implementation
Signed-off-by: Aman Khare <[email protected]>
Aman Khare committed Mar 5, 2024
1 parent bcbc00a commit 644d908
Showing 4 changed files with 278 additions and 33 deletions.
server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java
@@ -16,6 +16,7 @@
 import org.opensearch.core.index.shard.ShardId;
 import org.opensearch.indices.store.ShardAttributes;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -35,6 +36,9 @@
  */
 public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V extends BaseShardResponse> extends AsyncShardFetch<T> {
 
+    private final Consumer<ShardId> removeShardFromBatch;
+    private final List<ShardId> failedShards;
+
     @SuppressWarnings("unchecked")
     AsyncShardBatchFetch(
         Logger logger,
@@ -49,6 +53,8 @@ public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V extends
         Consumer<ShardId> handleFailedShard
     ) {
         super(logger, type, shardAttributesMap, action, batchId);
+        this.removeShardFromBatch = handleFailedShard;
+        this.failedShards = new ArrayList<>();
         this.cache = new ShardBatchCache<>(
             logger,
             type,
@@ -71,7 +77,6 @@ public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V extends
      * @return data received from the transport actions
      */
     public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
-        List<ShardId> failedShards = cleanUpFailedShards();
         if (failedShards.isEmpty() == false) {
             // trigger a reroute if there are any shards failed, to make sure they're picked up in next run
             logger.trace("triggering another reroute for failed shards in {}", reroutingKey);
@@ -81,16 +86,15 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId,
     }
 
     /**
-     * Remove the shard from shardAttributesMap so we don't send it in next fetching round.
+     * Remove the shard from shardAttributesMap, so we don't send it in next fetching round.
+     * Remove shard from the batch, so it gets picked up in a new batch in next reroute.
      *
-     * @return return the failed shards so a reroute can be triggered.
+     * @param shardId shardId to be cleaned up
      */
-    private List<ShardId> cleanUpFailedShards() {
-        List<ShardId> failedShards = cache.getFailedShards();
-        if (failedShards != null && failedShards.isEmpty() == false) {
-            shardAttributesMap.keySet().removeIf(failedShards::contains);
-        }
-        return failedShards;
+    private void cleanUpFailedShard(ShardId shardId) {
+        shardAttributesMap.remove(shardId);
+        removeShardFromBatch.accept(shardId);
+        failedShards.add(shardId);
     }

     /**
@@ -100,7 +104,7 @@ private List<ShardId> cleanUpFailedShards() {
      * @param shardId shardId to be removed from the batch.
      */
     public void clearShard(ShardId shardId) {
-        this.shardAttributesMap.remove(shardId);
-        this.cache.clearShardCache(shardId);
+        shardAttributesMap.remove(shardId);
+        cache.deleteData(shardId);
     }
 }
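
The net effect in this file is push-based failure handling: the cache is handed a callback (presumably `cleanUpFailedShard`; the constructor arguments are collapsed in this view) that removes a failed shard from `shardAttributesMap` and from the batch as soon as the failure is seen, and `fetchData` then only checks the accumulated `failedShards` list to decide whether to trigger another reroute. A minimal, self-contained sketch of that wiring, with `String` standing in for `ShardId` and every name outside the diff purely illustrative:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

public class FailedShardCallbackSketch {
    public static void main(String[] args) {
        // Stand-in for shardAttributesMap: shards still eligible for fetching.
        Map<String, String> shardAttributesMap = new HashMap<>();
        shardAttributesMap.put("shard-0", "attrs");
        // Failed shards accumulate here until the next fetchData() reroutes them.
        List<String> failedShards = new ArrayList<>();

        // Equivalent of cleanUpFailedShard: the fetcher hands this consumer to
        // the cache, which invokes it as soon as it sees a retryable failure.
        Consumer<String> handleFailedShard = shardId -> {
            shardAttributesMap.remove(shardId); // don't send it in the next fetch round
            failedShards.add(shardId);          // remember it so a reroute is triggered
        };

        handleFailedShard.accept("shard-0");
        System.out.println("pending reroute: " + failedShards); // [shard-0]
    }
}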
29 changes: 7 additions & 22 deletions server/src/main/java/org/opensearch/gateway/ShardBatchCache.java
@@ -51,7 +51,6 @@ public class ShardBatchCache<T extends BaseNodeResponse, V extends BaseShardResp
     private final Map<Integer, ShardId> arrayToShardId;
     private final Function<T, Map<ShardId, V>> shardsBatchDataGetter;
     private final Supplier<V> emptyResponseBuilder;
-    private final Set<ShardId> failedShards;
     private final Consumer<ShardId> handleFailedShard;
 
     public ShardBatchCache(
@@ -72,7 +71,6 @@ public ShardBatchCache(
         this.responseConstructor = responseConstructor;
         this.shardsBatchDataGetter = shardsBatchDataGetter;
         this.emptyResponseBuilder = emptyResponseBuilder;
-        failedShards = new HashSet<>();
         cache = new HashMap<>();
         shardIdToArray = new HashMap<>();
         arrayToShardId = new HashMap<>();
@@ -86,7 +84,7 @@ public ShardBatchCache(
     }
 
     @Override
-    public void clearShardCache(ShardId shardId) {
+    public void deleteData(ShardId shardId) {
         if (shardIdToArray.containsKey(shardId)) {
             Integer shardIdIndex = shardIdToArray.remove(shardId);
             for (String nodeId : cache.keySet()) {
@@ -128,19 +126,16 @@ public void initData(DiscoveryNode node) {
     public void putData(DiscoveryNode node, T response) {
         NodeEntry<V> nodeEntry = cache.get(node.getId());
         Map<ShardId, V> batchResponse = shardsBatchDataGetter.apply(response);
-        failedShards.addAll(filterFailedShards(batchResponse));
+        filterFailedShards(batchResponse);
         nodeEntry.doneFetching(batchResponse, shardIdToArray);
     }
 
     /**
      * Return the shard for which we got unhandled exceptions.
      *
      * @param batchResponse response from one node for the batch.
-     * @return List of failed shards.
      */
-    private List<ShardId> filterFailedShards(Map<ShardId, V> batchResponse) {
-        logger.trace("filtering failed shards");
-        List<ShardId> failedShards = new ArrayList<>();
+    private void filterFailedShards(Map<ShardId, V> batchResponse) {
         for (Iterator<ShardId> it = batchResponse.keySet().iterator(); it.hasNext();) {
             ShardId shardId = it.next();
             if (batchResponse.get(shardId) != null) {
@@ -149,33 +144,23 @@ private List<ShardId> filterFailedShards(Map<ShardId, V> batchResponse) {
                     // the batch
                     Exception shardException = batchResponse.get(shardId).getException();
                     // if the request got rejected or timed out, we need to try it again next time...
-                    if (shardException instanceof OpenSearchRejectedExecutionException
-                        || shardException instanceof ReceiveTimeoutTransportException
-                        || shardException instanceof OpenSearchTimeoutException) {
-                        logger.trace("got unhandled retryable exception for shard {} {}", shardId.toString(), shardException.toString());
-                        failedShards.add(shardId);
+                    if (retryableException(shardException)) {
+                        logger.trace("got unhandled retryable exception for shard {} {}", shardId.toString(),
+                            shardException.toString());
+                        handleFailedShard.accept(shardId);
                         // remove this failed entry. So, while storing the data, we don't need to re-process it.
                         it.remove();
                     }
                 }
             }
         }
-        return failedShards;
     }
 
     @Override
     public T getData(DiscoveryNode node) {
         return this.responseConstructor.apply(node, getBatchData(cache.get(node.getId())));
     }
 
-    @Override
-    public List<ShardId> getFailedShards() {
-        List<ShardId> defectedShards = List.copyOf(failedShards);
-        failedShards.clear();
-        return defectedShards;
-    }
-
     private HashMap<ShardId, V> getBatchData(NodeEntry<V> nodeEntry) {
         V[] nodeShardEntries = nodeEntry.getData();
         boolean[] emptyResponses = nodeEntry.getEmptyShardResponse();
@@ -197,7 +182,7 @@ private void fillShardIdKeys(Set<ShardId> shardIds) {
         }
         this.shardIdToArray.keySet().removeIf(shardId -> {
             if (!shardIds.contains(shardId)) {
-                clearShardCache(shardId);
+                deleteData(shardId);
                 return true;
             } else {
                 return false;
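
The deleted `instanceof` chain is folded into a `retryableException(shardException)` helper that sits outside the hunks shown here. Judging from the removed lines, it presumably reduces to the following sketch (a reconstruction, assuming the exception types resolve through imports the file already carries):

    // Hedged reconstruction from the deleted instanceof chain; the actual
    // helper is not visible in this diff view.
    static boolean retryableException(Exception shardException) {
        // requests that were rejected or timed out are worth retrying next round
        return shardException instanceof OpenSearchRejectedExecutionException
            || shardException instanceof ReceiveTimeoutTransportException
            || shardException instanceof OpenSearchTimeoutException;
    }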
25 changes: 25 additions & 0 deletions server/src/test/java/org/opensearch/gateway/BatchTestUtil.java
@@ -0,0 +1,25 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.gateway;
+
+import org.opensearch.core.index.shard.ShardId;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class BatchTestUtil {
+    public static List<ShardId> setUpShards(int numberOfShards) {
+        List<ShardId> shards = new ArrayList<>();
+        for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) {
+            ShardId shardId = new ShardId("test", "_na_", shardNumber);
+            shards.add(shardId);
+        }
+        return shards;
+    }
+}
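
A quick, hypothetical usage of the new helper (only `BatchTestUtil.setUpShards` comes from this commit; the demo class is illustrative):

import org.opensearch.core.index.shard.ShardId;
import org.opensearch.gateway.BatchTestUtil;

import java.util.List;

public class BatchTestUtilDemo {
    public static void main(String[] args) {
        // Build three ShardIds against index "test" (index UUID "_na_").
        List<ShardId> shards = BatchTestUtil.setUpShards(3);
        for (ShardId shardId : shards) {
            System.out.println(shardId); // prints [test][0], [test][1], [test][2]
        }
    }
}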
[diff for the fourth changed file did not load in this view]
