diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index dcb28b28f2b76..772f36898202b 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.core.Nullable; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.dfs.AggregatedDfs; @@ -39,13 +40,15 @@ final class FetchSearchPhase extends SearchPhase { private final Logger logger; private final SearchProgressListener progressListener; private final AggregatedDfs aggregatedDfs; + @Nullable + private final SearchPhaseResults resultConsumer; private final SearchPhaseController.ReducedQueryPhase reducedQueryPhase; FetchSearchPhase( SearchPhaseResults resultConsumer, AggregatedDfs aggregatedDfs, SearchPhaseContext context, - SearchPhaseController.ReducedQueryPhase reducedQueryPhase + @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase ) { this( resultConsumer, @@ -64,7 +67,7 @@ final class FetchSearchPhase extends SearchPhase { SearchPhaseResults resultConsumer, AggregatedDfs aggregatedDfs, SearchPhaseContext context, - SearchPhaseController.ReducedQueryPhase reducedQueryPhase, + @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase, BiFunction, SearchPhase> nextPhaseFactory ) { super("fetch"); @@ -85,6 +88,7 @@ final class FetchSearchPhase extends SearchPhase { this.logger = context.getLogger(); this.progressListener = context.getTask().getProgressListener(); this.reducedQueryPhase = reducedQueryPhase; + this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null; } @Override @@ -92,7 +96,7 @@ public void run() { context.execute(new AbstractRunnable() { @Override - protected void doRun() { + protected void doRun() throws Exception { innerRun(); } @@ -103,7 +107,10 @@ public void onFailure(Exception e) { }); } - private void innerRun() { + private void innerRun() throws Exception { + assert this.reducedQueryPhase == null ^ this.resultConsumer == null; + // depending on whether we executed the RankFeaturePhase we may or may not have the reduced query result computed already + final var reducedQueryPhase = this.reducedQueryPhase == null ? resultConsumer.reduce() : this.reducedQueryPhase; final int numShards = context.getNumShards(); // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase. @@ -113,7 +120,7 @@ private void innerRun() { if (queryAndFetchOptimization) { assert assertConsistentWithQueryAndFetchOptimization(); // query AND fetch optimization - moveToNextPhase(searchPhaseShardResults); + moveToNextPhase(searchPhaseShardResults, reducedQueryPhase); } else { ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // no docs to fetch -- sidestep everything and return @@ -121,7 +128,7 @@ private void innerRun() { // we have to release contexts here to free up resources searchPhaseShardResults.asList() .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context)); - moveToNextPhase(fetchResults.getAtomicArray()); + moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase); } else { final boolean shouldExplainRank = shouldExplainRankScores(context.getRequest()); final List> rankDocsPerShard = false == shouldExplainRank @@ -134,7 +141,7 @@ private void innerRun() { final CountedCollector counter = new CountedCollector<>( fetchResults, docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not - () -> moveToNextPhase(fetchResults.getAtomicArray()), + () -> moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase), context ); for (int i = 0; i < docIdsToLoad.length; i++) { @@ -243,7 +250,10 @@ public void onFailure(Exception e) { ); } - private void moveToNextPhase(AtomicArray fetchResultsArr) { + private void moveToNextPhase( + AtomicArray fetchResultsArr, + SearchPhaseController.ReducedQueryPhase reducedQueryPhase + ) { var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr); context.addReleasable(resp::decRef); fetchResults.close(); diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index 49e5c1b6d69e3..81053a70eca9f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -70,6 +70,12 @@ public class RankFeaturePhase extends SearchPhase { @Override public void run() { + RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source()); + if (rankFeaturePhaseRankCoordinatorContext == null) { + moveToNextPhase(queryPhaseResults, null); + return; + } + context.execute(new AbstractRunnable() { @Override protected void doRun() throws Exception { @@ -77,7 +83,7 @@ protected void doRun() throws Exception { // was set up at FetchSearchPhase. // we do the heavy lifting in this inner run method where we reduce aggs etc - innerRun(); + innerRun(rankFeaturePhaseRankCoordinatorContext); } @Override @@ -87,51 +93,39 @@ public void onFailure(Exception e) { }); } - void innerRun() throws Exception { + void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) throws Exception { // if the RankBuilder specifies a QueryPhaseCoordinatorContext, it will be called as part of the reduce call // to operate on the first `rank_window_size * num_shards` results and merge them appropriately. SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce(); - RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source()); - if (rankFeaturePhaseRankCoordinatorContext != null) { - ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size - final List[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs); - final CountedCollector rankRequestCounter = new CountedCollector<>( - rankPhaseResults, - context.getNumShards(), - () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase), - context - ); + ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size + final List[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs); + final CountedCollector rankRequestCounter = new CountedCollector<>( + rankPhaseResults, + context.getNumShards(), + () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase), + context + ); - // we send out a request to each shard in order to fetch the needed feature info - for (int i = 0; i < docIdsToLoad.length; i++) { - List entry = docIdsToLoad[i]; - SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i); - if (entry == null || entry.isEmpty()) { - if (queryResult != null) { - releaseIrrelevantSearchContext(queryResult, context); - progressListener.notifyRankFeatureResult(i); - } - rankRequestCounter.countDown(); - } else { - executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry); + // we send out a request to each shard in order to fetch the needed feature info + for (int i = 0; i < docIdsToLoad.length; i++) { + List entry = docIdsToLoad[i]; + SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i); + if (entry == null || entry.isEmpty()) { + if (queryResult != null) { + releaseIrrelevantSearchContext(queryResult, context); + progressListener.notifyRankFeatureResult(i); } + rankRequestCounter.countDown(); + } else { + executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry); } - } else { - moveToNextPhase(queryPhaseResults, reducedQueryPhase); } } private RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source) { return source == null || source.rankBuilder() == null ? null - : context.getRequest() - .source() - .rankBuilder() - .buildRankFeaturePhaseCoordinatorContext( - context.getRequest().source().size(), - context.getRequest().source().from(), - client - ); + : source.rankBuilder().buildRankFeaturePhaseCoordinatorContext(source.size(), source.from(), client); } private void executeRankFeatureShardPhase( diff --git a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java index 68a161d426f48..82463d601d164 100644 --- a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java @@ -536,7 +536,7 @@ public void sendExecuteRankFeature( // override the RankFeaturePhase to raise an exception RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(results, null, mockSearchPhaseContext, null) { @Override - void innerRun() { + void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) { throw new IllegalArgumentException("simulated failure"); } @@ -1142,7 +1142,13 @@ public void moveToNextPhase( ) { // this is called after the RankFeaturePhaseCoordinatorContext has been executed phaseDone.set(true); - finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs(); + try { + finalResults[0] = reducedQueryPhase == null + ? queryPhaseResults.reduce().sortedTopDocs().scoreDocs() + : reducedQueryPhase.sortedTopDocs().scoreDocs(); + } catch (Exception e) { + throw new AssertionError(e); + } logger.debug("Skipping moving to next phase"); } }; diff --git a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java index b3efff4323c20..5862e1bd1329f 100644 --- a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java @@ -14,6 +14,7 @@ import org.apache.lucene.search.TotalHits; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.document.DocumentField; import org.elasticsearch.common.io.stream.StreamOutput; @@ -170,7 +171,12 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) // no work to be done on the coordinator node for the rank feature phase @Override public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) { - return null; + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + throw new AssertionError("not expected"); + } + }; } @Override