Use single exception parameter, modify exception handling in batch mode
Signed-off-by: Aman Khare <[email protected]>
Aman Khare committed Apr 2, 2024
1 parent f8093b1 commit ba6cbb4
Showing 10 changed files with 85 additions and 289 deletions.
server/src/internalClusterTest/java/org/opensearch/gateway/RecoveryFromGatewayIT.java
@@ -56,7 +56,7 @@
 import org.opensearch.core.index.Index;
 import org.opensearch.core.index.shard.ShardId;
 import org.opensearch.env.NodeEnvironment;
-import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.GatewayStartedShard;
+import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.GatewayStartedShard;
 import org.opensearch.index.IndexService;
 import org.opensearch.index.IndexSettings;
 import org.opensearch.index.MergePolicyProvider;
@@ -818,9 +818,9 @@ public void testShardFetchCorruptedShardsUsingBatchAction() throws Exception {
             .get(discoveryNodes[0].getId())
             .getNodeGatewayStartedShardsBatch()
             .get(shardId);
-        assertNotNull(gatewayStartedShard.get().storeException());
-        assertNotNull(gatewayStartedShard.get().allocationId());
-        assertTrue(gatewayStartedShard.get().primary());
+        assertNotNull(gatewayStartedShard.storeException());
+        assertNotNull(gatewayStartedShard.allocationId());
+        assertTrue(gatewayStartedShard.primary());
     }

     public void testSingleShardStoreFetchUsingBatchAction() throws ExecutionException, InterruptedException {
@@ -949,9 +949,9 @@ private void assertNodeStoreFilesMetadataSuccessCase(
     }

     private void assertNodeGatewayStartedShardsHappyCase(GatewayStartedShard gatewayStartedShard) {
-        assertNull(gatewayStartedShard.get().storeException());
-        assertNotNull(gatewayStartedShard.get().allocationId());
-        assertTrue(gatewayStartedShard.get().primary());
+        assertNull(gatewayStartedShard.storeException());
+        assertNotNull(gatewayStartedShard.allocationId());
+        assertTrue(gatewayStartedShard.primary());
     }

     private void prepareIndex(String indexName, int numberOfPrimaryShards) {
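A note on the assertion change above: the batch response previously handed back a holder that had to be unwrapped with get(); it now returns the GatewayStartedShard value directly. Below is a minimal sketch of the accessor surface the updated tests rely on. Only the accessor names are taken from this diff; the constructor and field layout are assumptions, and the real class lives in TransportNodesGatewayStartedShardHelper.

    import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;

    // Sketch only: accessor names mirror the calls in this commit; everything else is assumed.
    public class GatewayStartedShard {
        private final String allocationId;
        private final boolean primary;
        private final Exception storeException;
        private final ReplicationCheckpoint replicationCheckpoint;

        public GatewayStartedShard(
            String allocationId,
            boolean primary,
            Exception storeException,
            ReplicationCheckpoint replicationCheckpoint
        ) {
            this.allocationId = allocationId;
            this.primary = primary;
            this.storeException = storeException;
            this.replicationCheckpoint = replicationCheckpoint;
        }

        public String allocationId() { return allocationId; }
        public boolean primary() { return primary; }
        public Exception storeException() { return storeException; }
        public ReplicationCheckpoint replicationCheckpoint() { return replicationCheckpoint; }
    }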
169 changes: 36 additions & 133 deletions server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java
@@ -12,22 +12,16 @@
 import org.opensearch.action.support.nodes.BaseNodeResponse;
 import org.opensearch.action.support.nodes.BaseNodesResponse;
 import org.opensearch.cluster.node.DiscoveryNode;
-import org.opensearch.cluster.node.DiscoveryNodes;
 import org.opensearch.common.logging.Loggers;
 import org.opensearch.core.index.shard.ShardId;
 import org.opensearch.indices.store.ShardAttributes;

 import java.lang.reflect.Array;
-import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.BiFunction;
-import java.util.function.Consumer;
 import java.util.function.Function;
-import java.util.function.Supplier;

 /**
  * Implementation of AsyncShardFetch with batching support. This class is responsible for executing the fetch
@@ -43,9 +37,6 @@
  */
 public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V> extends AsyncShardFetch<T> {

-    private final Consumer<ShardId> removeShardFromBatch;
-    private final List<ShardId> failedShards;
-
     @SuppressWarnings("unchecked")
     AsyncShardBatchFetch(
         Logger logger,
@@ -56,57 +47,29 @@ public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V> extends AsyncShardFetch<T> {
         Class<V> clazz,
         BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseGetter,
         Function<T, Map<ShardId, V>> shardsBatchDataGetter,
-        Supplier<V> emptyResponseBuilder,
-        Consumer<ShardId> failedShardHandler,
-        Function<V, Exception> getResponseException,
+        V emptyResponse,
         Function<V, Boolean> isEmptyResponse
     ) {
-        super(logger, type, shardAttributesMap, action, batchId);
-        this.removeShardFromBatch = failedShardHandler;
-        this.failedShards = new ArrayList<>();
-        this.cache = new ShardBatchCache<>(
+        super(
             logger,
             type,
             shardAttributesMap,
-            "BatchID=[" + batchId + "]",
-            clazz,
-            responseGetter,
-            shardsBatchDataGetter,
-            emptyResponseBuilder,
-            this::cleanUpFailedShards,
-            getResponseException,
-            isEmptyResponse
+            action,
+            batchId,
+            new ShardBatchCache<>(
+                logger,
+                type,
+                shardAttributesMap,
+                "BatchID=[" + batchId + "]",
+                clazz,
+                responseGetter,
+                shardsBatchDataGetter,
+                emptyResponse,
+                isEmptyResponse
+            )
         );
     }

-    public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
-        FetchResult<T> result = super.fetchData(nodes, ignoreNodes);
-        if (result.hasData()) {
-            // trigger reroute for failed shards only when all nodes have completed fetching
-            if (failedShards.isEmpty() == false) {
-                // trigger a reroute if there are any shards failed, to make sure they're picked up in next run
-                logger.trace("triggering another reroute for failed shards in {}", reroutingKey);
-                reroute("shards-failed", "shards failed in " + reroutingKey);
-                failedShards.clear();
-            }
-        }
-        return result;
-    }
-
-    /**
-     * Remove the shard from shardAttributesMap so it's not sent in next asyncFetch.
-     * Call removeShardFromBatch method to remove the shardId from the batch object created in
-     * ShardsBatchGatewayAllocator.
-     * Add shardId to failedShards, so it can be used to trigger another reroute as part of upcoming fetchData call.
-     *
-     * @param shardId shardId to be cleaned up from batch and cache.
-     */
-    private void cleanUpFailedShards(ShardId shardId) {
-        shardAttributesMap.remove(shardId);
-        removeShardFromBatch.accept(shardId);
-        failedShards.add(shardId);
-    }
-
     /**
      * Remove a shard from the cache maintaining a full batch of shards. This is needed to clear the shard once it's
      * assigned or failed.
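Two things happen in the hunk above: the three per-shard failure hooks (Supplier<V> emptyResponseBuilder, Consumer<ShardId> failedShardHandler, Function<V, Exception> getResponseException) collapse into one shared emptyResponse instance, and the reroute-on-failure path (the fetchData override plus cleanUpFailedShards) disappears, leaving any storeException inside the cached response for the allocators to inspect. A compilable toy showing just the shape of the signature change; all names are hypothetical, not the real OpenSearch types:

    import java.util.function.Consumer;
    import java.util.function.Function;
    import java.util.function.Supplier;

    // K is the shard key type, V the per-shard response type. Only the parameter
    // shapes matter here; the bodies are stubs.
    class OldStyleBatchFetch<K, V> {
        OldStyleBatchFetch(
            Supplier<V> emptyResponseBuilder,      // built a fresh empty response per shard
            Consumer<K> failedShardHandler,        // evicted failed shards and triggered reroute
            Function<V, Exception> getResponseException,
            Function<V, Boolean> isEmptyResponse
        ) {}
    }

    class NewStyleBatchFetch<K, V> {
        NewStyleBatchFetch(
            V emptyResponse,                       // one shared instance, built once
            Function<V, Boolean> isEmptyResponse   // exceptions stay inside V; no eviction hook
        ) {}
    }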
@@ -119,7 +82,10 @@ public void clearShard(ShardId shardId) {
     }

     /**
-     * Cache implementation of transport actions returning batch of shards related data in the response. It'll
+     * Cache implementation of transport actions returning batch of shards related data in the response.
+     * Store node level responses of transport actions like {@link TransportNodesListGatewayStartedShardsBatch} or
+     * {@link org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch} with memory efficient caching
+     * approach.
      *
      * @param <T> Response type of transport action.
      * @param <V> Data type of shard level response.
@@ -130,11 +96,8 @@ static class ShardBatchCache<T extends BaseNodeResponse, V> extends AsyncShardFetchCache<T> {
         private final int batchSize;
         private final Class<V> shardResponseClass;
         private final BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseConstructor;
-        private final Map<Integer, ShardId> arrayToShardId;
         private final Function<T, Map<ShardId, V>> shardsBatchDataGetter;
-        private final Supplier<V> emptyResponseBuilder;
-        private final Consumer<ShardId> failedShardHandler;
-        private final Function<V, Exception> getException;
+        private final V emptyResponse;
         private final Function<V, Boolean> isEmpty;
         private final Logger logger;

@@ -146,24 +109,19 @@ public ShardBatchCache(
             Class<V> clazz,
             BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseGetter,
             Function<T, Map<ShardId, V>> shardsBatchDataGetter,
-            Supplier<V> emptyResponseBuilder,
-            Consumer<ShardId> failedShardHandler,
-            Function<V, Exception> getResponseException,
+            V emptyResponse,
             Function<V, Boolean> isEmptyResponse
         ) {
             super(Loggers.getLogger(logger, "_" + logKey), type);
             this.batchSize = shardAttributesMap.size();
-            this.getException = getResponseException;
             this.isEmpty = isEmptyResponse;
             cache = new HashMap<>();
             shardIdToArray = new HashMap<>();
-            arrayToShardId = new HashMap<>();
             fillShardIdKeys(shardAttributesMap.keySet());
             this.shardResponseClass = clazz;
             this.responseConstructor = responseGetter;
             this.shardsBatchDataGetter = shardsBatchDataGetter;
-            this.emptyResponseBuilder = emptyResponseBuilder;
-            this.failedShardHandler = failedShardHandler;
+            this.emptyResponse = emptyResponse;
             this.logger = logger;
         }

@@ -182,26 +140,9 @@ public void deleteShard(ShardId shardId) {
             }
         }

-        @Override
-        public Map<DiscoveryNode, T> getCacheData(DiscoveryNodes nodes, Set<String> failedNodes) {
-            fillReverseIdMap();
-            return super.getCacheData(nodes, failedNodes);
-        }
-
-        /**
-         * Build a reverse map to get shardId from the array index, this will be used to construct the response which
-         * PrimaryShardBatchAllocator or ReplicaShardBatchAllocator are looking for.
-         */
-        private void fillReverseIdMap() {
-            arrayToShardId.clear();
-            for (Map.Entry<ShardId, Integer> indexMapping : shardIdToArray.entrySet()) {
-                arrayToShardId.putIfAbsent(indexMapping.getValue(), indexMapping.getKey());
-            }
-        }
-
         @Override
         public void initData(DiscoveryNode node) {
-            cache.put(node.getId(), new NodeEntry<>(node.getId(), shardResponseClass, batchSize, getException, isEmpty));
+            cache.put(node.getId(), new NodeEntry<>(node.getId(), shardResponseClass, batchSize, isEmpty));
         }

         /**
@@ -216,40 +157,9 @@ public void initData(DiscoveryNode node) {
         public void putData(DiscoveryNode node, T response) {
             NodeEntry<V> nodeEntry = cache.get(node.getId());
             Map<ShardId, V> batchResponse = shardsBatchDataGetter.apply(response);
-            filterFailedShards(batchResponse);
             nodeEntry.doneFetching(batchResponse, shardIdToArray);
         }

-        /**
-         * Return the shard for which we got unhandled exceptions.
-         *
-         * @param batchResponse response from one node for the batch.
-         */
-        private void filterFailedShards(Map<ShardId, V> batchResponse) {
-            logger.trace("filtering failed shards");
-            for (Iterator<ShardId> it = batchResponse.keySet().iterator(); it.hasNext();) {
-                ShardId shardId = it.next();
-                if (batchResponse.get(shardId) != null) {
-                    if (getException.apply(batchResponse.get(shardId)) != null) {
-                        // handle per shard level exceptions, process other shards, only throw out this shard from
-                        // the batch
-                        Exception shardException = getException.apply(batchResponse.get(shardId));
-                        // if the request got rejected or timed out, we need to try it again next time...
-                        if (retryableException(shardException)) {
-                            logger.trace(
-                                "got unhandled retryable exception for shard {} {}",
-                                shardId.toString(),
-                                shardException.toString()
-                            );
-                            failedShardHandler.accept(shardId);
-                            // remove this failed entry. So, while storing the data, we don't need to re-process it.
-                            it.remove();
-                        }
-                    }
-                }
-            }
-        }
-
         @Override
         public T getData(DiscoveryNode node) {
             return this.responseConstructor.apply(node, getBatchData(cache.get(node.getId())));
@@ -259,12 +169,14 @@ private HashMap<ShardId, V> getBatchData(NodeEntry<V> nodeEntry) {
             V[] nodeShardEntries = nodeEntry.getData();
             boolean[] emptyResponses = nodeEntry.getEmptyShardResponse();
             HashMap<ShardId, V> shardData = new HashMap<>();
-            for (Integer shardIdIndex : shardIdToArray.values()) {
-                if (emptyResponses[shardIdIndex]) {
-                    shardData.put(arrayToShardId.get(shardIdIndex), emptyResponseBuilder.get());
-                } else if (nodeShardEntries[shardIdIndex] != null) {
+            for (Map.Entry<ShardId, Integer> shardIdIndex : shardIdToArray.entrySet()) {
+                ShardId shardId = shardIdIndex.getKey();
+                Integer arrIndex = shardIdIndex.getValue();
+                if (emptyResponses[arrIndex]) {
+                    shardData.put(shardId, emptyResponse);
+                } else if (nodeShardEntries[arrIndex] != null) {
                     // ignore null responses here
-                    shardData.put(arrayToShardId.get(shardIdIndex), nodeShardEntries[shardIdIndex]);
+                    shardData.put(shardId, nodeShardEntries[arrIndex]);
                 }
             }
             return shardData;
@@ -288,20 +200,12 @@ static class NodeEntry<V> extends BaseNodeEntry {
             // actually needed in allocation/explain API response. So instead of storing full empty response object
             // in cache, it's better to just store a boolean and create that object on the fly just before
             // decision-making.
-            private final Function<V, Exception> getException;
             private final Function<V, Boolean> isEmpty;

-            NodeEntry(
-                String nodeId,
-                Class<V> clazz,
-                int batchSize,
-                Function<V, Exception> getResponseException,
-                Function<V, Boolean> isEmptyResponse
-            ) {
+            NodeEntry(String nodeId, Class<V> clazz, int batchSize, Function<V, Boolean> isEmptyResponse) {
                 super(nodeId);
                 this.shardData = (V[]) Array.newInstance(clazz, batchSize);
                 this.emptyShardResponse = new boolean[batchSize];
-                this.getException = getResponseException;
                 this.isEmpty = isEmptyResponse;
             }

@@ -324,15 +228,14 @@ boolean[] getEmptyShardResponse() {
            }

            private void fillShardData(Map<ShardId, V> shardDataFromNode, Map<ShardId, Integer> shardIdKey) {
-               for (ShardId shardId : shardDataFromNode.keySet()) {
-                   if (shardDataFromNode.get(shardId) != null) {
-                       if (isEmpty.apply(shardDataFromNode.get(shardId))) {
+               for (Map.Entry<ShardId, V> shardData : shardDataFromNode.entrySet()) {
+                   if (shardData.getValue() != null) {
+                       ShardId shardId = shardData.getKey();
+                       if (isEmpty.apply(shardData.getValue())) {
                            this.emptyShardResponse[shardIdKey.get(shardId)] = true;
                            this.shardData[shardIdKey.get(shardId)] = null;
-                       } else if (getException.apply(shardDataFromNode.get(shardId)) == null) {
-                           this.shardData[shardIdKey.get(shardId)] = shardDataFromNode.get(shardId);
+                       } else {
+                           this.shardData[shardIdKey.get(shardId)] = shardData.getValue();
                        }
-                       // if exception is not null, we got unhandled failure for the shard which needs to be ignored
                    }
                }
            }
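Taken together, NodeEntry and getBatchData implement the space optimization described in the comment above: one shared empty-response instance plus a boolean per array slot replaces a per-shard empty object, and the object is only materialized when the batch is read back. A self-contained sketch of that idea with invented names; the real classes are generic over the transport response type and keyed by ShardId rather than int:

    import java.util.HashMap;
    import java.util.Map;

    // Illustration of the boolean-for-empty trick used by NodeEntry/getBatchData above.
    class BatchSlots<V> {
        private final V[] data;
        private final boolean[] empty;
        private final V emptyResponse; // one shared instance for the whole batch

        @SuppressWarnings("unchecked")
        BatchSlots(int batchSize, V emptyResponse) {
            this.data = (V[]) new Object[batchSize]; // the real code uses Array.newInstance(clazz, batchSize)
            this.empty = new boolean[batchSize];
            this.emptyResponse = emptyResponse;
        }

        void put(int slot, V value, boolean isEmpty) {
            if (isEmpty) {
                empty[slot] = true; // remember emptiness without holding an object
                data[slot] = null;
            } else {
                data[slot] = value;
            }
        }

        Map<Integer, V> snapshot() {
            Map<Integer, V> out = new HashMap<>();
            for (int i = 0; i < data.length; i++) {
                if (empty[i]) {
                    out.put(i, emptyResponse); // materialize the shared empty response on read
                } else if (data[i] != null) {
                    out.put(i, data[i]); // null means "nothing fetched", so it is skipped
                }
            }
            return out;
        }
    }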
server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java
@@ -120,14 +120,15 @@ protected AsyncShardFetch(
         String type,
         Map<ShardId, ShardAttributes> shardAttributesMap,
         Lister<? extends BaseNodesResponse<T>, T> action,
-        String batchId
+        String batchId,
+        AsyncShardFetchCache<T> cache
     ) {
         this.logger = logger;
         this.type = type;
         this.shardAttributesMap = shardAttributesMap;
         this.action = (Lister<BaseNodesResponse<T>, T>) action;
         this.reroutingKey = "BatchID=[" + batchId + "]";
-        cache = new ShardCache<>(logger, reroutingKey, type);
+        this.cache = cache;
     }

     @Override
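The removed line above meant every fetcher got a default ShardCache even when a subclass immediately replaced it, which is exactly what AsyncShardBatchFetch used to do with `this.cache = new ShardBatchCache<>(...)`. Passing the cache through the constructor allocates it exactly once and lets the subclass pick the implementation. A small sketch of the before/after pattern, with placeholder types only:

    // Placeholder types; only the initialization order is the point.
    class Cache {}

    class OldBase {
        protected Cache cache;

        OldBase() {
            this.cache = new Cache(); // default allocated unconditionally...
        }
    }

    class OldBatch extends OldBase {
        OldBatch() {
            super();
            this.cache = new Cache(); // ...then discarded and replaced by the subclass
        }
    }

    class NewBase {
        protected final Cache cache; // in this sketch the field can even be final

        NewBase(Cache cache) {
            this.cache = cache;
        }
    }

    class NewBatch extends NewBase {
        NewBatch() {
            super(new Cache()); // single allocation, chosen by the subclass
        }
    }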
server/src/main/java/org/opensearch/gateway/PrimaryShardBatchAllocator.java
@@ -15,8 +15,8 @@
 import org.opensearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.opensearch.cluster.routing.allocation.RoutingAllocation;
 import org.opensearch.gateway.AsyncShardFetch.FetchResult;
+import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.GatewayStartedShard;
 import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.NodeGatewayStartedShard;
-import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.GatewayStartedShard;
 import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch;

 import java.util.ArrayList;
@@ -136,10 +136,10 @@ private static List<NodeGatewayStartedShard> adaptToNodeShardStates(
             GatewayStartedShard shardData = nodeGatewayStartedShardsBatch.getNodeGatewayStartedShardsBatch().get(unassignedShard.shardId());
             nodeShardStates.add(
                 new NodeGatewayStartedShard(
-                    shardData.get().allocationId(),
-                    shardData.get().primary(),
-                    shardData.get().replicationCheckpoint(),
-                    shardData.get().storeException(),
+                    shardData.allocationId(),
+                    shardData.primary(),
+                    shardData.replicationCheckpoint(),
+                    shardData.storeException(),
                     node
                 )
             );
server/src/main/java/org/opensearch/gateway/TransportNodesGatewayStartedShardHelper.java
@@ -233,8 +233,11 @@ public String toString() {
             return buf.toString();
         }

-        public Boolean isEmpty() {
-            return allocationId == null && primary == false && storeException == null && replicationCheckpoint == null;
+        public static Boolean isEmpty(GatewayStartedShard gatewayStartedShard) {
+            return gatewayStartedShard.allocationId() == null
+                && gatewayStartedShard.primary() == false
+                && gatewayStartedShard.storeException() == null
+                && gatewayStartedShard.replicationCheckpoint() == null;
         }
     }

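Here isEmpty moves from an instance method to a static helper that takes the instance as a parameter, matching how the batch-fetch layer consumes it: as the Function<V, Boolean> isEmptyResponse argument introduced above. A toy demonstration of that wiring, with an invented class and field (an unbound instance-method reference would also satisfy Function<V, Boolean>, so the static form mainly keeps the emptiness rule beside the fields it reads):

    import java.util.function.Function;

    // Toy stand-in for GatewayStartedShard, showing how the static isEmpty is
    // handed to batch-fetch code as a Function<V, Boolean>. Names are illustrative.
    class Shard {
        final String allocationId;

        Shard(String allocationId) {
            this.allocationId = allocationId;
        }

        static Boolean isEmpty(Shard s) {
            return s.allocationId == null; // the real check also inspects primary, storeException, checkpoint
        }
    }

    public class IsEmptyDemo {
        public static void main(String[] args) {
            Function<Shard, Boolean> isEmptyResponse = Shard::isEmpty;
            System.out.println(isEmptyResponse.apply(new Shard(null))); // true
            System.out.println(isEmptyResponse.apply(new Shard("a1"))); // false
        }
    }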
[remaining changed files in this commit not shown]