Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify AbstractSearchAsyncAction #52935

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NoShardAvailableActionException;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.ShardOperationFailedException;
import org.elasticsearch.action.search.TransportSearchAction.SearchTimeProvider;
import org.elasticsearch.action.support.TransportActions;
Expand Down Expand Up @@ -227,7 +228,10 @@ private void performPhaseOnShard(final int shardIndex, final SearchShardIterator
Runnable r = () -> {
final Thread thread = Thread.currentThread();
try {
executePhaseOnShard(shardIt, shard,
final Transport.Connection connection = getConnection(shardIt.getClusterAlias(), shard.currentNodeId());
final ShardSearchRequest request = buildShardSearchRequest(
shardIt.getClusterAlias(), shard.shardId(), shardIt.getOriginalIndices());
executePhaseOnShard(connection, request,
new SearchActionListener<Result>(shardIt.newSearchShardTarget(shard.currentNodeId()), shardIndex) {
@Override
public void innerOnResponse(Result result) {
Expand Down Expand Up @@ -269,11 +273,12 @@ public void onFailure(Exception t) {

/**
* Sends the request to the actual shard.
* @param shardIt the shards iterator
* @param shard the shard routing to send the request for
* @param listener the listener to notify on response
*
* @param connection the connection to the target node
* @param request the shard search request to be executed
* @param listener the listener to notify on response
*/
protected abstract void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting shard, SearchActionListener<Result> listener);
abstract void executePhaseOnShard(Transport.Connection connection, ShardSearchRequest request, SearchActionListener<Result> listener);

private void fork(final Runnable runnable) {
executor.execute(new AbstractRunnable() {
Expand Down Expand Up @@ -598,15 +603,14 @@ public final void onFailure(Exception e) {
}

@Override
public final ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt) {
AliasFilter filter = aliasFilter.get(shardIt.shardId().getIndex().getUUID());
public final ShardSearchRequest buildShardSearchRequest(String clusterAlias, ShardId shardId, OriginalIndices originalIndices) {
AliasFilter filter = aliasFilter.get(shardId.getIndex().getUUID());
assert filter != null;
float indexBoost = concreteIndexBoosts.getOrDefault(shardIt.shardId().getIndex().getUUID(), DEFAULT_INDEX_BOOST);
String indexName = shardIt.shardId().getIndex().getName();
final String[] routings = indexRoutings.getOrDefault(indexName, Collections.emptySet())
float indexBoost = concreteIndexBoosts.getOrDefault(shardId.getIndex().getUUID(), DEFAULT_INDEX_BOOST);
final String[] routings = indexRoutings.getOrDefault(shardId.getIndexName(), Collections.emptySet())
.toArray(new String[0]);
ShardSearchRequest shardRequest = new ShardSearchRequest(shardIt.getOriginalIndices(), request, shardIt.shardId(), getNumShards(),
filter, indexBoost, timeProvider.getAbsoluteStartMillis(), shardIt.getClusterAlias(), routings);
ShardSearchRequest shardRequest = new ShardSearchRequest(originalIndices, request, shardId, getNumShards(),
filter, indexBoost, timeProvider.getAbsoluteStartMillis(), clusterAlias, routings);
// if we already received a search result we can inform the shard that it
// can return a null response if the request rewrites to match none rather
// than creating an empty response in the search thread pool.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import org.apache.lucene.util.FixedBitSet;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.search.SearchService.CanMatchResponse;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.MinAndMax;
import org.elasticsearch.search.sort.SortOrder;
Expand Down Expand Up @@ -78,10 +78,8 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
}

@Override
protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting shard,
SearchActionListener<CanMatchResponse> listener) {
getSearchTransport().sendCanMatch(getConnection(shardIt.getClusterAlias(), shard.currentNodeId()),
buildShardSearchRequest(shardIt), getTask(), listener);
void executePhaseOnShard(Transport.Connection connection, ShardSearchRequest request, SearchActionListener<CanMatchResponse> listener) {
getSearchTransport().sendCanMatch(connection, request, getTask(), listener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.search.dfs.DfsSearchResult;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.transport.Transport;

import java.util.Map;
Expand All @@ -51,10 +51,8 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
}

@Override
protected void executePhaseOnShard(final SearchShardIterator shardIt, final ShardRouting shard,
final SearchActionListener<DfsSearchResult> listener) {
getSearchTransport().sendExecuteDfs(getConnection(shardIt.getClusterAlias(), shard.currentNodeId()),
buildShardSearchRequest(shardIt) , getTask(), listener);
void executePhaseOnShard(Transport.Connection connection, ShardSearchRequest request, SearchActionListener<DfsSearchResult> listener) {
getSearchTransport().sendExecuteDfs(connection, request, getTask(), listener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.search.internal.ShardSearchRequest;
Expand Down Expand Up @@ -109,7 +110,7 @@ default void sendReleaseSearchContext(long contextId, Transport.Connection conne
/**
* Builds an request for the initial search phase.
*/
ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt);
ShardSearchRequest buildShardSearchRequest(String clusterAlias, ShardId shardId, OriginalIndices originalIndices);

/**
* Processes the phase transition from on phase to another. This method handles all errors that happen during the initial run execution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.transport.Transport;

import java.util.Map;
Expand Down Expand Up @@ -57,10 +57,10 @@ final class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<Se
sourceBuilder == null || sourceBuilder.size() != 0);
}

protected void executePhaseOnShard(final SearchShardIterator shardIt, final ShardRouting shard,
final SearchActionListener<SearchPhaseResult> listener) {
getSearchTransport().sendExecuteQuery(getConnection(shardIt.getClusterAlias(), shard.currentNodeId()),
buildShardSearchRequest(shardIt), getTask(), listener);
@Override
protected void executePhaseOnShard(Transport.Connection connection, ShardSearchRequest request,
SearchActionListener<SearchPhaseResult> listener) {
getSearchTransport().sendExecuteQuery(connection, request, getTask(), listener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.Index;
Expand Down Expand Up @@ -99,8 +98,9 @@ protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> r
}

@Override
protected void executePhaseOnShard(final SearchShardIterator shardIt, final ShardRouting shard,
final SearchActionListener<SearchPhaseResult> listener) {
void executePhaseOnShard(Transport.Connection connection, ShardSearchRequest request,
SearchActionListener<SearchPhaseResult> listener) {

}

@Override
Expand Down Expand Up @@ -144,9 +144,9 @@ public void testBuildShardSearchTransportRequest() {
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest,
new ArraySearchPhaseResults<>(10), null, false, expected);
String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10);
SearchShardIterator iterator = new SearchShardIterator(clusterAlias, new ShardId(new Index("name", "foo"), 1),
Collections.emptyList(), new OriginalIndices(new String[] {"name", "name1"}, IndicesOptions.strictExpand()));
ShardSearchRequest shardSearchTransportRequest = action.buildShardSearchRequest(iterator);
final OriginalIndices originalIndices = new OriginalIndices(new String[]{"name", "name1"}, IndicesOptions.strictExpand());
ShardSearchRequest shardSearchTransportRequest = action.buildShardSearchRequest(
clusterAlias, new ShardId(new Index("name", "foo"), 1), originalIndices);
assertEquals(IndicesOptions.strictExpand(), shardSearchTransportRequest.indicesOptions());
assertArrayEquals(new String[] {"name", "name1"}, shardSearchTransportRequest.indices());
assertEquals(new MatchAllQueryBuilder(), shardSearchTransportRequest.getAliasFilter().getQueryBuilder());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.index.shard.ShardId;
Expand Down Expand Up @@ -260,10 +259,8 @@ public void run() {
}

@Override
protected void executePhaseOnShard(
final SearchShardIterator shardIt,
final ShardRouting shard,
final SearchActionListener<SearchPhaseResult> listener) {
void executePhaseOnShard(Transport.Connection connection, ShardSearchRequest request,
SearchActionListener<SearchPhaseResult> listener) {
if (randomBoolean()) {
listener.onResponse(new SearchPhaseResult() {});
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.search.internal.ShardSearchRequest;
Expand Down Expand Up @@ -110,9 +111,8 @@ public SearchTransportService getSearchTransport() {
}

@Override
public ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt) {
Assert.fail("should not be called");
return null;
public ShardSearchRequest buildShardSearchRequest(String clusterAlias, ShardId shardId, OriginalIndices originalIndices) {
throw new AssertionError("should not be called");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException;
Expand Down Expand Up @@ -113,17 +114,16 @@ public void testSkipSearchShards() throws InterruptedException {
SearchResponse.Clusters.EMPTY) {

@Override
protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting shard,
SearchActionListener<TestSearchPhaseResult> listener) {
seenShard.computeIfAbsent(shard.shardId(), (i) -> {
void executePhaseOnShard(Transport.Connection connection, ShardSearchRequest request,
SearchActionListener<TestSearchPhaseResult> listener) {
seenShard.computeIfAbsent(request.shardId(), (i) -> {
numRequests.incrementAndGet(); // only count this once per replica
return Boolean.TRUE;
});

new Thread(() -> {
Transport.Connection connection = getConnection(null, shard.currentNodeId());
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(contextIdGenerator.incrementAndGet(),
connection.getNode());
connection.getNode());
listener.onResponse(testSearchPhaseResult);

}).start();
Expand Down Expand Up @@ -218,9 +218,9 @@ public void testLimitConcurrentShardRequests() throws InterruptedException {
SearchResponse.Clusters.EMPTY) {

@Override
protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting shard,
SearchActionListener<TestSearchPhaseResult> listener) {
seenShard.computeIfAbsent(shard.shardId(), (i) -> {
void executePhaseOnShard(Transport.Connection connection, ShardSearchRequest request,
SearchActionListener<TestSearchPhaseResult> listener) {
seenShard.computeIfAbsent(request.shardId(), (i) -> {
numRequests.incrementAndGet(); // only count this once per shard copy
return Boolean.TRUE;
});
Expand All @@ -231,10 +231,9 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting sha
} catch (InterruptedException e) {
throw new AssertionError(e);
}
Transport.Connection connection = getConnection(null, shard.currentNodeId());
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(contextIdGenerator.incrementAndGet(),
connection.getNode());
if (shardFailures[shard.shardId().id()]) {
if (shardFailures[request.shardId().id()]) {
listener.onFailure(new RuntimeException());
} else {
listener.onResponse(testSearchPhaseResult);
Expand Down Expand Up @@ -322,10 +321,9 @@ public void sendFreeContext(Transport.Connection connection, long contextId, Ori
TestSearchResponse response = new TestSearchResponse();

@Override
protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting shard, SearchActionListener<TestSearchPhaseResult>
listener) {
assertTrue("shard: " + shard.shardId() + " has been queried twice", response.queried.add(shard.shardId()));
Transport.Connection connection = getConnection(null, shard.currentNodeId());
void executePhaseOnShard(Transport.Connection connection, ShardSearchRequest request,
SearchActionListener<TestSearchPhaseResult> listener) {
assertTrue("shard: " + request.shardId() + " has been queried twice", response.queried.add(request.shardId()));
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(contextIdGenerator.incrementAndGet(),
connection.getNode());
Set<Long> ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> newConcurrentSet());
Expand Down Expand Up @@ -389,6 +387,10 @@ public void testAllowPartialResults() throws InterruptedException {
GroupShardsIterator<SearchShardIterator> shardsIter = getShardsIter("idx",
new OriginalIndices(new String[]{"idx"}, SearchRequest.DEFAULT_INDICES_OPTIONS),
numShards, true, primaryNode, replicaNode);
Map<ShardId, Integer> remainingShards = new HashMap<>();
for (SearchShardIterator shardIt : shardsIter) {
remainingShards.put(shardIt.shardId(), shardIt.size());
}
int numShardAttempts = 0;
for (SearchShardIterator it : shardsIter) {
numShardAttempts += it.remaining();
Expand Down Expand Up @@ -426,20 +428,20 @@ public void testAllowPartialResults() throws InterruptedException {
SearchResponse.Clusters.EMPTY) {

@Override
protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting shard,
SearchActionListener<TestSearchPhaseResult> listener) {
seenShard.computeIfAbsent(shard.shardId(), (i) -> {
void executePhaseOnShard(Transport.Connection connection, ShardSearchRequest request,
SearchActionListener<TestSearchPhaseResult> listener) {
seenShard.computeIfAbsent(request.shardId(), (i) -> {
numRequests.incrementAndGet(); // only count this once per shard copy
return Boolean.TRUE;
});
new Thread(() -> {
Transport.Connection connection = getConnection(null, shard.currentNodeId());
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(contextIdGenerator.incrementAndGet(),
connection.getNode());
if (shardIt.remaining() > 0) {
final int remaining = remainingShards.compute(request.shardId(), (k, v) -> v - 1);
if (remaining > 0) {
numFailReplicas.incrementAndGet();
listener.onFailure(new RuntimeException());
} else {
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(contextIdGenerator.incrementAndGet(),
connection.getNode());
listener.onResponse(testSearchPhaseResult);
}
}).start();
Expand Down